diff --git a/cli/BUILD b/cli/BUILD index 2eebdabb1..053df52d4 100644 --- a/cli/BUILD +++ b/cli/BUILD @@ -65,7 +65,13 @@ ts_test_suite( "index_project_test.ts", "index_compile_test.ts", "index_run_e2e_test.ts", - "util_test.ts" + "tests/jit/index_jit_main_test.ts", + "tests/jit/index_jit_advanced_test.ts", + "tests/jit/index_jit_dependency_test.ts", + "tests/jit/index_jit_runtime_test.ts", + "util_test.ts", + "tests/jit/jit_build_test.ts", + "tests/jit/jit_run_test.ts", ], data = [ ":node_modules", diff --git a/cli/api/BUILD b/cli/api/BUILD index 352f3ef06..0b2e2f0c4 100644 --- a/cli/api/BUILD +++ b/cli/api/BUILD @@ -58,7 +58,8 @@ node_modules( ts_test_suite( name = "tests", srcs = [ - "utils_test.ts", + "commands/jit/rpc_test.ts", + "dbadapters/bigquery_test.ts", ], data = [ ":node_modules", @@ -68,10 +69,11 @@ ts_test_suite( "@nodejs//:npm", ], deps = [ - "//cli/api", + ":api", "//core", "//protos:ts", "//testing", + "@npm//@google-cloud/bigquery", "@npm//@types/chai", "@npm//@types/fs-extra", "@npm//@types/js-yaml", diff --git a/cli/api/commands/base_worker.ts b/cli/api/commands/base_worker.ts new file mode 100644 index 000000000..f1c420722 --- /dev/null +++ b/cli/api/commands/base_worker.ts @@ -0,0 +1,74 @@ +import { ChildProcess, fork } from "child_process"; + +export abstract class BaseWorker { + protected constructor(private readonly loaderPath: string) {} + + protected async runWorker( + timeoutMillis: number, + onBoot: (child: ChildProcess) => void, + onMessage: (message: TMessage, child: ChildProcess, resolve: (res: TResponse) => void, reject: (err: Error) => void) => void + ): Promise { + const forkScript = this.resolveScript(); + const child = fork(forkScript, [], { + stdio: [0, 1, 2, "ipc", "pipe"] + }); + + return new Promise((resolve, reject) => { + let completed = false; + let booted = false; + + const terminate = (fn: () => void) => { + if (completed) { + return; + } + completed = true; + clearTimeout(timeout); + child.kill("SIGKILL"); + fn(); + }; + + const timeout = setTimeout(() => { + terminate(() => + reject(new Error(`Worker timed out after ${timeoutMillis / 1000} seconds`)) + ); + }, timeoutMillis); + + child.on("message", (message: any) => { + if (message.type === "worker_booted") { + if (!booted) { + booted = true; + onBoot(child); + } + return; + } + onMessage(message, child, (res) => terminate(() => resolve(res)), (err) => terminate(() => reject(err))); + }); + + child.on("error", err => { + terminate(() => reject(err)); + }); + + child.on("exit", (code, signal) => { + if (!completed) { + const errorMsg = + code !== 0 && code !== null + ? `Worker exited with code ${code} and signal ${signal}` + : "Worker exited without sending a response message"; + terminate(() => reject(new Error(errorMsg))); + } + }); + }); + } + + private resolveScript() { + const pathsToTry = ["./worker_bundle.js", this.loaderPath]; + for (const p of pathsToTry) { + try { + return require.resolve(p); + } catch (e) { + // Continue to next path. + } + } + throw new Error(`Could not resolve worker script. Tried: ${pathsToTry.join(", ")}`); + } +} diff --git a/cli/api/commands/build.ts b/cli/api/commands/build.ts index 99538381a..fb22365fb 100644 --- a/cli/api/commands/build.ts +++ b/cli/api/commands/build.ts @@ -69,7 +69,8 @@ export class Builder { runConfig: this.runConfig, warehouseState: this.warehouseState, declarationTargets: this.prunedGraph.declarations.map(declaration => declaration.target), - actions + actions, + jitData: this.prunedGraph.jitData }); } @@ -82,9 +83,7 @@ export class Builder { ...this.toPartialExecutionAction(table), type: "table", tableType: utils.tableTypeEnumToString(table.enumType), - tasks: table.disabled - ? [] - : this.executionSql.publishTasks(table, runConfig, tableMetadata).build(), + tasks: this.executionSql.createTableTasks(table, runConfig, tableMetadata), hermeticity: table.hermeticity || dataform.ActionHermeticity.HERMETIC }; } @@ -93,9 +92,7 @@ export class Builder { return { ...this.toPartialExecutionAction(operation), type: "operation", - tasks: operation.disabled - ? [] - : operation.queries.map(statement => ({ type: "statement", statement })), + tasks: this.executionSql.createOperationTasks(operation), hermeticity: operation.hermeticity || dataform.ActionHermeticity.NON_HERMETIC }; } @@ -104,9 +101,7 @@ export class Builder { return { ...this.toPartialExecutionAction(assertion), type: "assertion", - tasks: assertion.disabled - ? [] - : this.executionSql.assertTasks(assertion, this.prunedGraph.projectConfig).build(), + tasks: this.executionSql.createAssertionTasks(assertion), hermeticity: assertion.hermeticity || dataform.ActionHermeticity.HERMETIC }; } @@ -114,11 +109,17 @@ export class Builder { private toPartialExecutionAction( action: dataform.ITable | dataform.IOperation | dataform.IAssertion ) { - return dataform.ExecutionAction.create({ + const jitCode = (action as any).jitCode; + const executionAction = dataform.ExecutionAction.create({ target: action.target, fileName: action.fileName, dependencyTargets: action.dependencyTargets, - actionDescriptor: action.actionDescriptor + actionDescriptor: action.actionDescriptor, + disabled: action.disabled }); + if (jitCode) { + executionAction.jitCode = jitCode; + } + return executionAction; } } diff --git a/cli/api/commands/compile.ts b/cli/api/commands/compile.ts index ac863a9a3..6e144a5d0 100644 --- a/cli/api/commands/compile.ts +++ b/cli/api/commands/compile.ts @@ -4,8 +4,10 @@ import * as path from "path"; import * as tmp from "tmp"; import { promisify } from "util"; +import { BaseWorker } from "df/cli/api/commands/base_worker"; import { MISSING_CORE_VERSION_ERROR } from "df/cli/api/commands/install"; import { readConfigFromWorkflowSettings } from "df/cli/api/utils"; +import { DEFAULT_COMPILATION_TIMEOUT_MILLIS } from "df/cli/api/utils/constants"; import { coerceAsError } from "df/common/errors/errors"; import { decode64 } from "df/common/protos"; import { dataform } from "df/protos/ts"; @@ -86,7 +88,7 @@ export async function compile( compileConfig.projectDir = temporaryProjectPath; } - const result = await CompileChildProcess.forkProcess().compile(compileConfig); + const result = await new CompileChildProcess().compile(compileConfig); const decodedResult = decode64(dataform.CoreExecutionResponse, result); compiledGraph = dataform.CompiledGraph.create(decodedResult.compile.compiledGraph); @@ -98,68 +100,24 @@ export async function compile( return compiledGraph; } -export class CompileChildProcess { - public static forkProcess() { - // Runs the worker_bundle script we generate for the package (see packages/@dataform/cli/BUILD) - // if it exists, otherwise run the bazel compile loader target. - const findForkScript = () => { - try { - const workerBundlePath = require.resolve("./worker_bundle.js"); - return workerBundlePath; - } catch (e) { - return require.resolve("../../vm/compile_loader"); - } - }; - const forkScript = findForkScript(); - return new CompileChildProcess( - fork(require.resolve(forkScript), [], { stdio: [0, 1, 2, "ipc", "pipe"] }) - ); - } - private readonly childProcess: ChildProcess; - - constructor(childProcess: ChildProcess) { - this.childProcess = childProcess; +export class CompileChildProcess extends BaseWorker { + constructor() { + super(path.resolve(__dirname, "../../vm/compile_loader")); } public async compile(compileConfig: dataform.ICompileConfig) { - const compileInChildProcess = new Promise(async (resolve, reject) => { - this.childProcess.on("error", (e: Error) => reject(coerceAsError(e))); - - this.childProcess.on("message", (messageOrError: string | Error) => { - if (typeof messageOrError === "string") { - resolve(messageOrError); - return; - } - reject(coerceAsError(messageOrError)); - }); - - this.childProcess.on("close", exitCode => { - if (exitCode !== 0) { - reject(new Error(`Compilation child process exited with exit code ${exitCode}.`)); + const timeoutValue = compileConfig.timeoutMillis || DEFAULT_COMPILATION_TIMEOUT_MILLIS; + + return await this.runWorker( + timeoutValue, + child => child.send(compileConfig), + (message, child, resolve, reject) => { + if (typeof message === "string") { + resolve(message); + } else { + reject(coerceAsError(message)); } - }); - - // Trigger the child process to start compiling. - this.childProcess.send(compileConfig); - }); - let timer; - const timeout = new Promise( - (resolve, reject) => - (timer = setTimeout( - () => reject(new CompilationTimeoutError("Compilation timed out")), - compileConfig.timeoutMillis || 5000 - )) - ); - try { - await Promise.race([timeout, compileInChildProcess]); - return await compileInChildProcess; - } finally { - if (!this.childProcess.killed) { - this.childProcess.kill("SIGKILL"); } - if (timer) { - clearTimeout(timer); - } - } + ); } } diff --git a/cli/api/commands/jit/compiler.ts b/cli/api/commands/jit/compiler.ts new file mode 100644 index 000000000..651da96f8 --- /dev/null +++ b/cli/api/commands/jit/compiler.ts @@ -0,0 +1,103 @@ +import { ChildProcess } from "child_process"; +import * as path from "path"; + +import { BaseWorker } from "df/cli/api/commands/base_worker"; +import { handleDbRequest } from "df/cli/api/commands/jit/rpc"; +import { IDbAdapter, IDbClient } from "df/cli/api/dbadapters"; +import { IBigQueryExecutionOptions } from "df/cli/api/dbadapters/bigquery"; +import { DEFAULT_COMPILATION_TIMEOUT_MILLIS } from "df/cli/api/utils/constants"; +import { dataform } from "df/protos/ts"; + +export interface IJitWorkerMessage { + type: "rpc_request" | "jit_response" | "jit_error"; + method?: string; + request?: Uint8Array; + correlationId?: string; + response?: Uint8Array; + error?: string; +} + +export class JitCompileChildProcess extends BaseWorker< + dataform.IJitCompilationResponse, + IJitWorkerMessage +> { + public static async compile( + request: dataform.IJitCompilationRequest, + projectDir: string, + dbadapter: IDbAdapter, + dbclient: IDbClient, + timeoutMillis: number = DEFAULT_COMPILATION_TIMEOUT_MILLIS, + options?: IBigQueryExecutionOptions + ): Promise { + return await new JitCompileChildProcess().run( + request, + projectDir, + dbadapter, + dbclient, + timeoutMillis, + options + ); + } + + constructor() { + super(path.resolve(__dirname, "../../../vm/jit_loader")); + } + + private async run( + request: dataform.IJitCompilationRequest, + projectDir: string, + dbadapter: IDbAdapter, + dbclient: IDbClient, + timeoutMillis: number, + options?: IBigQueryExecutionOptions + ): Promise { + return await this.runWorker( + timeoutMillis, + child => { + child.send({ + type: "jit_compile", + request, + projectDir + }); + }, + async (message, child, resolve, reject) => { + if (message.type === "rpc_request") { + await this.handleRpcRequest(message, child, dbadapter, dbclient, options); + } else if (message.type === "jit_response") { + resolve(dataform.JitCompilationResponse.fromObject(message.response)); + } else if (message.type === "jit_error") { + reject(new Error(message.error)); + } + } + ); + } + + private async handleRpcRequest( + message: IJitWorkerMessage, + child: ChildProcess, + dbadapter: IDbAdapter, + dbclient: IDbClient, + options?: IBigQueryExecutionOptions + ) { + try { + const response = await handleDbRequest( + dbadapter, + dbclient, + message.method, + message.request, + options + ); + child.send({ + type: "rpc_response", + correlationId: message.correlationId, + response + }); + } catch (e) { + child.send({ + type: "rpc_response", + correlationId: message.correlationId, + error: e.message + }); + } + } +} diff --git a/cli/api/commands/jit/rpc.ts b/cli/api/commands/jit/rpc.ts new file mode 100644 index 000000000..c9bfea071 --- /dev/null +++ b/cli/api/commands/jit/rpc.ts @@ -0,0 +1,92 @@ +import Long from "long"; + +import { IDbAdapter, IDbClient } from "df/cli/api/dbadapters"; +import { IBigQueryExecutionOptions } from "df/cli/api/dbadapters/bigquery"; +import { Structs } from "df/common/protos/structs"; +import { dataform, google } from "df/protos/ts"; + +export async function handleDbRequest( + dbadapter: IDbAdapter, + dbclient: IDbClient, + method: string, + request: Uint8Array, + options?: IBigQueryExecutionOptions +): Promise { + switch (method) { + case "Execute": + return await handleExecute(dbclient, request, options); + case "ListTables": + return await handleListTables(dbadapter, request); + case "GetTable": + return await handleGetTable(dbadapter, request); + case "DeleteTable": + return await handleDeleteTable(dbadapter, request, options?.dryRun); + default: + throw new Error(`Unrecognized RPC method: ${method}`); + } +} + +async function handleExecute( + dbclient: IDbClient, + request: Uint8Array, + options?: IBigQueryExecutionOptions +): Promise { + const executeRequest = dataform.ExecuteRequest.decode(request); + const executeRequestObj = dataform.ExecuteRequest.toObject(executeRequest, { + defaults: false + }); + const requestOptions = executeRequestObj.bigQueryOptions; + + const results = await dbclient.executeRaw(executeRequest.statement, { + rowLimit: executeRequest.rowLimit ? (executeRequest.rowLimit as Long).toNumber() : undefined, + params: Structs.toObject(executeRequest.params), + bigquery: { + ...options, + ...requestOptions, + labels: { + ...options?.labels, + ...requestOptions?.labels + }, + jobPrefix: [options?.jobPrefix, requestOptions?.jobPrefix].filter(Boolean).join("-") || undefined + } + }); + + return dataform.ExecuteResponse.encode({ + rows: (results.rows || []).map(row => Structs.fromObject(row)), + schemaFields: results.schema || [] + } as any).finish(); +} + +async function handleListTables(dbadapter: IDbAdapter, request: Uint8Array): Promise { + const listTablesRequest = dataform.ListTablesRequest.decode(request); + if (!listTablesRequest.database) { + throw new Error("ListTablesRequest.database must be supplied"); + } + const tablesMetadata = await dbadapter.tables(listTablesRequest.database, listTablesRequest.schema); + const listTablesResponse = dataform.ListTablesResponse.create({ + tables: tablesMetadata + }); + return dataform.ListTablesResponse.encode(listTablesResponse).finish(); +} + +async function handleGetTable(dbadapter: IDbAdapter, request: Uint8Array): Promise { + const getTableRequest = dataform.GetTableRequest.decode(request); + const tableMetadata = await dbadapter.table(getTableRequest.target); + if (!tableMetadata) { + throw new Error(`Table not found: ${JSON.stringify(getTableRequest.target)}`); + } + return dataform.TableMetadata.encode(tableMetadata).finish(); +} + +async function handleDeleteTable( + dbadapter: IDbAdapter, + request: Uint8Array, + dryRun?: boolean +): Promise { + const deleteTableRequest = dataform.DeleteTableRequest.decode(request); + if (dryRun) { + return new Uint8Array(); + } + await dbadapter.deleteTable(deleteTableRequest.target); + return new Uint8Array(); +} diff --git a/cli/api/commands/jit/rpc_test.ts b/cli/api/commands/jit/rpc_test.ts new file mode 100644 index 000000000..f9fbd9bc3 --- /dev/null +++ b/cli/api/commands/jit/rpc_test.ts @@ -0,0 +1,499 @@ +import { expect } from "chai"; +import Long from "long"; +import { anything, capture, instance, mock, verify, when } from "ts-mockito"; + +import { handleDbRequest } from "df/cli/api/commands/jit/rpc"; +import { IDbAdapter, IDbClient } from "df/cli/api/dbadapters"; +import { dataform } from "df/protos/ts"; +import { suite, test } from "df/testing"; + +suite("jit_rpc", () => { + test("Execute RPC maps to client.execute with all options", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const statement = "SELECT * FROM table"; + const executeRequest = dataform.ExecuteRequest.create({ + statement, + rowLimit: Long.fromNumber(100), + byteLimit: Long.fromNumber(1024), + bigQueryOptions: { + interactive: true, + location: "US", + labels: { key: "val" }, + jobPrefix: "prefix", + dryRun: true + } + }); + const encodedRequest = dataform.ExecuteRequest.encode(executeRequest).finish(); + + // Real raw BigQuery f/v format + const rawRows = [ + { + f: [ + { v: "42" }, + { v: "val" }, + { v: "true" }, + { v: null } + ] + } + ]; + + const schema = [ + { name: "num", primitive: dataform.Field.Primitive.INTEGER }, + { name: "str", primitive: dataform.Field.Primitive.STRING }, + { name: "bool", primitive: dataform.Field.Primitive.BOOLEAN }, + { name: "n", primitive: dataform.Field.Primitive.STRING } + ]; + when(mockClient.executeRaw(statement, anything())).thenResolve({ + rows: rawRows, + schema, + metadata: { bigquery: { jobId: "job1" } } + }); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest); + const decoded = dataform.ExecuteResponse.decode(response); + + expect(decoded.rows.length).equals(1); + const row = decoded.rows[0]; + const fList = row.fields.f.listValue.values; + expect(fList[0].structValue.fields.v.stringValue).equals("42"); + expect(fList[1].structValue.fields.v.stringValue).equals("val"); + expect(fList[2].structValue.fields.v.stringValue).equals("true"); + expect(fList[3].structValue.fields.v.nullValue).equals(0); + + expect(decoded.schemaFields.length).equals(4); + expect(decoded.schemaFields[0].name).equals("num"); + expect(decoded.schemaFields[1].name).equals("str"); + + verify(mockClient.executeRaw(statement, anything())).once(); + const capturedArgs = capture(mockClient.executeRaw).last(); + expect(capturedArgs[0]).to.equal(statement); + const capturedOptions = capturedArgs[1]; + expect(capturedOptions.bigquery.location).equals("US"); + expect(capturedOptions.bigquery.labels).deep.equals({ key: "val" }); + expect(capturedOptions.bigquery.jobPrefix).equals("prefix"); + expect(capturedOptions.bigquery.dryRun).equals(true); + }); + + test("DeleteTable RPC calls adapter and returns empty buffer", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const target = { database: "db", schema: "sch", name: "tab" }; + const request = dataform.DeleteTableRequest.create({ target }); + const encodedRequest = dataform.DeleteTableRequest.encode(request).finish(); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "DeleteTable", encodedRequest); + + verify(mockAdapter.deleteTable(anything())).once(); + const capturedTarget = capture(mockAdapter.deleteTable).last()[0]; + expect(dataform.Target.create(capturedTarget)).deep.equals(dataform.Target.create(target)); + expect(response.length).equals(0); + }); + + test("Execute RPC handles null and empty result sets", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const statement = "SELECT null as n"; + const encodedRequest = dataform.ExecuteRequest.encode(dataform.ExecuteRequest.create({ statement })).finish(); + + // Test with a null value + when(mockClient.executeRaw(statement, anything())).thenResolve({ + rows: [ + { + f: [{ v: null }] + } + ], + schema: [{ name: "n", primitive: dataform.Field.Primitive.STRING }], + metadata: {} + }); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest); + const decoded = dataform.ExecuteResponse.decode(response); + + expect(decoded.rows.length).equals(1); + const fListNull = decoded.rows[0].fields.f.listValue.values; + expect(fListNull[0].structValue.fields.v.nullValue).equals(0); // Protobuf NullValue.NULL_VALUE is 0 + + verify(mockClient.executeRaw(statement, anything())).once(); + const capturedArgs1 = capture(mockClient.executeRaw).last(); + expect(capturedArgs1[0]).to.equal(statement); + expect(capturedArgs1[1].bigquery.dryRun).equals(undefined); + + // Test with empty rows + when(mockClient.executeRaw(statement, anything())).thenResolve({ + rows: [], + metadata: {} + }); + + const responseEmpty = await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest); + const decodedEmpty = dataform.ExecuteResponse.decode(responseEmpty); + expect(decodedEmpty.rows.length).equals(0); + + verify(mockClient.executeRaw(statement, anything())).twice(); + const capturedArgs2 = capture(mockClient.executeRaw).last(); + expect(capturedArgs2[0]).to.equal(statement); + expect(capturedArgs2[1].bigquery.dryRun).equals(undefined); + }); + + test("ListTables RPC returns tables from adapter", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const request = dataform.ListTablesRequest.create({ database: "db", schema: "sch" }); + const encodedRequest = dataform.ListTablesRequest.encode(request).finish(); + + const target1 = { database: "db", schema: "sch", name: "table1" }; + const metadata1 = { target: target1, type: dataform.TableMetadata.Type.TABLE } as any; + when(mockAdapter.tables("db", "sch")).thenResolve([metadata1]); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "ListTables", encodedRequest); + const decoded = dataform.ListTablesResponse.decode(response); + + expect(decoded.tables.length).equals(1); + expect(decoded.tables[0].target.name).equals("table1"); + + verify(mockAdapter.tables("db", "sch")).once(); + verify(mockAdapter.table(anything())).never(); + }); + + test("ListTables RPC throws error when database is missing", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + // Request without database + const request = dataform.ListTablesRequest.create({ schema: "sch" }); + const encodedRequest = dataform.ListTablesRequest.encode(request).finish(); + + try { + await handleDbRequest(instance(mockAdapter), instance(mockClient), "ListTables", encodedRequest); + expect.fail("Should have thrown an error"); + } catch (e) { + expect(e.message).to.equal("ListTablesRequest.database must be supplied"); + } + + verify(mockAdapter.tables(anything(), anything())).never(); + }); + + test("GetTable RPC returns metadata from adapter", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const target = { database: "db", schema: "sch", name: "tab" }; + const request = dataform.GetTableRequest.create({ target }); + const encodedRequest = dataform.GetTableRequest.encode(request).finish(); + + when(mockAdapter.table(anything())).thenResolve({ target } as any); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "GetTable", encodedRequest); + const decoded = dataform.TableMetadata.decode(response); + + expect(decoded.target.name).equals("tab"); + verify(mockAdapter.table(anything())).once(); + const capturedTarget = capture(mockAdapter.table).last()[0]; + expect(dataform.Target.create(capturedTarget)).deep.equals(dataform.Target.create(target)); + }); + + test("GetTable RPC throws error when table not found", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const target = { database: "db", schema: "sch", name: "missing" }; + const request = dataform.GetTableRequest.create({ target }); + const encodedRequest = dataform.GetTableRequest.encode(request).finish(); + + // Adapter returns null for missing table + when(mockAdapter.table(anything())).thenResolve(null); + + try { + await handleDbRequest(instance(mockAdapter), instance(mockClient), "GetTable", encodedRequest); + expect.fail("Should have thrown an error"); + } catch (e) { + expect(e.message).to.contain("Table not found"); + expect(e.message).to.contain("missing"); + } + + verify(mockAdapter.table(anything())).once(); + }); + + test("DeleteTable RPC respects global dryRun flag", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const request = dataform.DeleteTableRequest.create({ + target: { database: "db", schema: "sch", name: "tab" } + }); + const encodedRequest = dataform.DeleteTableRequest.encode(request).finish(); + + // Call with dryRun = true + await handleDbRequest(instance(mockAdapter), instance(mockClient), "DeleteTable", encodedRequest, { dryRun: true }); + + // Verify that the adapter method was NOT called + verify(mockAdapter.deleteTable(anything())).never(); + }); + + test("Execute RPC respects global dryRun flag", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const statement = "SELECT 1"; + const encodedRequest = dataform.ExecuteRequest.encode(dataform.ExecuteRequest.create({ statement })).finish(); + + when(mockClient.executeRaw(anything(), anything())).thenResolve({ rows: [], metadata: {} }); + + // Call with dryRun = true + await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest, { dryRun: true }); + + verify(mockClient.executeRaw(statement, anything())).once(); + const capturedArgs = capture(mockClient.executeRaw).last(); + expect(capturedArgs[0]).to.equal(statement); + expect(capturedArgs[1].bigquery.dryRun).to.equal(true); + }); + + test("Throws error for unrecognized RPC method", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + try { + await handleDbRequest(instance(mockAdapter), instance(mockClient), "UnknownMethod", new Uint8Array()); + expect.fail("Should have thrown"); + } catch (e) { + expect(e.message).to.contain("Unrecognized RPC method"); + } + }); + + test("Execute RPC merges global BigQuery options with request-specific options", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + const statement = "SELECT 1"; + const executeRequest = dataform.ExecuteRequest.create({ + statement, + bigQueryOptions: { + location: "EU", + labels: { request_label: "request_val" }, + jobPrefix: "request-prefix", + dryRun: true + } + }); + const encodedRequest = dataform.ExecuteRequest.encode(executeRequest).finish(); + + const globalOptions = { + labels: { global_label: "global_val" }, + location: "US", // Request should override this to EU + jobPrefix: "global-prefix" + }; + + when(mockClient.executeRaw(statement, anything())).thenResolve({ rows: [], metadata: {} }); + + await handleDbRequest( + instance(mockAdapter), + instance(mockClient), + "Execute", + encodedRequest, + globalOptions + ); + + verify(mockClient.executeRaw(statement, anything())).once(); + const capturedOptions = capture(mockClient.executeRaw).last()[1]; + + // We expect both labels to be present + expect(capturedOptions.bigquery.labels).deep.equals({ + global_label: "global_val", + request_label: "request_val" + }); + // We expect request location to override global location + expect(capturedOptions.bigquery.location).equals("EU"); + // We expect job prefixes to be merged + expect(capturedOptions.bigquery.jobPrefix).equals("global-prefix-request-prefix"); + }); + + test("Execute RPC label merging: both global and request labels", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + const statement = "SELECT 1"; + const encodedRequest = dataform.ExecuteRequest.encode({ + statement, + bigQueryOptions: { labels: { request_label: "request_val" } } + }).finish(); + + when(mockClient.executeRaw(statement, anything())).thenResolve({ rows: [], metadata: {} }); + + await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest, { + labels: { + global_label: "global_val", + request_label: "global_override_attempt" + } + }); + + const capturedOptions = capture(mockClient.executeRaw).last()[1]; + expect(capturedOptions.bigquery.labels).deep.equals({ + global_label: "global_val", + request_label: "request_val" + }); + }); + + test("Execute RPC label merging: undefined global, defined request", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + const statement = "SELECT 1"; + const encodedRequest = dataform.ExecuteRequest.encode({ + statement, + bigQueryOptions: { labels: { request_label: "request_val" } } + }).finish(); + + when(mockClient.executeRaw(statement, anything())).thenResolve({ rows: [], metadata: {} }); + + // Global options have no labels + await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest, { + location: "US" + }); + + const capturedOptions = capture(mockClient.executeRaw).last()[1]; + expect(capturedOptions.bigquery.labels).deep.equals({ + request_label: "request_val" + }); + }); + + test("Execute RPC label merging: empty global, defined request", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + const statement = "SELECT 1"; + const encodedRequest = dataform.ExecuteRequest.encode({ + statement, + bigQueryOptions: { labels: { request_label: "request_val" } } + }).finish(); + + when(mockClient.executeRaw(statement, anything())).thenResolve({ rows: [], metadata: {} }); + + // Global options have empty labels object + await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest, { + labels: {} + }); + + const capturedOptions = capture(mockClient.executeRaw).last()[1]; + expect(capturedOptions.bigquery.labels).deep.equals({ + request_label: "request_val" + }); + }); + + test("Execute RPC label merging: defined global, undefined request", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + const statement = "SELECT 1"; + // Request has no labels + const encodedRequest = dataform.ExecuteRequest.encode({ + statement, + bigQueryOptions: { location: "US" } + }).finish(); + + when(mockClient.executeRaw(statement, anything())).thenResolve({ rows: [], metadata: {} }); + + await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest, { + labels: { global_label: "global_val" } + }); + + const capturedOptions = capture(mockClient.executeRaw).last()[1]; + expect(capturedOptions.bigquery.labels).deep.equals({ + global_label: "global_val" + }); + }); + + test("Execute RPC label merging: defined global, empty request", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + const statement = "SELECT 1"; + // Request has empty labels + const encodedRequest = dataform.ExecuteRequest.encode({ + statement, + bigQueryOptions: { labels: {} } + }).finish(); + + when(mockClient.executeRaw(statement, anything())).thenResolve({ rows: [], metadata: {} }); + + await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest, { + labels: { global_label: "global_val" } + }); + + const capturedOptions = capture(mockClient.executeRaw).last()[1]; + expect(capturedOptions.bigquery.labels).deep.equals({ + global_label: "global_val" + }); + }); + + test("Execute RPC handles raw BigQuery f,v format results", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const statement = "SELECT * FROM table"; + const encodedRequest = dataform.ExecuteRequest.encode(dataform.ExecuteRequest.create({ statement })).finish(); + + // Real raw BigQuery f/v format + const rawRows = [ + { + f: [ + { v: "42" } + ] + } + ]; + + when(mockClient.executeRaw(statement, anything())).thenResolve({ + rows: rawRows, + schema: [{ name: "id", primitive: dataform.Field.Primitive.STRING }], + metadata: { bigquery: { jobId: "job1" } } + }); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest); + const decoded = dataform.ExecuteResponse.decode(response); + + expect(decoded.rows.length).equals(1); + const row = decoded.rows[0]; + expect(row.fields.f).to.not.equal(undefined); + const fList = row.fields.f.listValue.values; + expect(fList[0].structValue.fields.v.stringValue).equals("42"); + expect(decoded.schemaFields.length).equals(1); + expect(decoded.schemaFields[0].name).equals("id"); + }); + + test("Execute RPC preserves complex nested BigQuery f,v format", async () => { + const mockAdapter = mock(); + const mockClient = mock(); + + const statement = "SELECT complex_struct FROM table"; + const encodedRequest = dataform.ExecuteRequest.encode(dataform.ExecuteRequest.create({ statement })).finish(); + + // Real raw BigQuery complex nested f/v format + const rawRows = [ + { + f: [ + { + v: { + f: [ + { v: "nested_val" }, + { v: "123" } + ] + } + } + ] + } + ]; + + when(mockClient.executeRaw(statement, anything())).thenResolve({ + rows: rawRows, + schema: [{ name: "complex_struct", primitive: dataform.Field.Primitive.STRING }], + metadata: { bigquery: { jobId: "job1" } } + }); + + const response = await handleDbRequest(instance(mockAdapter), instance(mockClient), "Execute", encodedRequest); + const decoded = dataform.ExecuteResponse.decode(response); + + expect(decoded.rows.length).equals(1); + const row = decoded.rows[0]; + const nestedStruct = row.fields.f.listValue.values[0].structValue.fields.v.structValue; + const nestedFList = nestedStruct.fields.f.listValue.values; + expect(nestedFList[0].structValue.fields.v.stringValue).equals("nested_val"); + expect(decoded.schemaFields.length).equals(1); + expect(decoded.schemaFields[0].name).equals("complex_struct"); + }); +}); diff --git a/cli/api/commands/run.ts b/cli/api/commands/run.ts index f618aa09d..6074e687b 100644 --- a/cli/api/commands/run.ts +++ b/cli/api/commands/run.ts @@ -1,12 +1,16 @@ import EventEmitter from "events"; import Long from "long"; +import { JitCompileChildProcess } from "df/cli/api/commands/jit/compiler"; import * as dbadapters from "df/cli/api/dbadapters"; import { IBigQueryExecutionOptions } from "df/cli/api/dbadapters/bigquery"; +import { ExecutionSql } from "df/cli/api/dbadapters/execution_sql"; +import { DEFAULT_COMPILATION_TIMEOUT_MILLIS } from "df/cli/api/utils/constants"; import { Flags } from "df/common/flags"; import { retry } from "df/common/promises"; import { deepClone, equals } from "df/common/protos"; import { targetStringifier } from "df/core/targets"; +import { version } from "df/core/version"; import { dataform } from "df/protos/ts"; const CANCEL_EVENT = "jobCancel"; @@ -24,18 +28,29 @@ export interface IExecutedAction { } export interface IExecutionOptions { + projectDir?: string; bigquery?: { jobPrefix?: string; actionRetryLimit?: number; dryRun?: boolean; labels?: { [label: string]: string }; }; + jitCompiler?: ( + request: dataform.IJitCompilationRequest, + projectDir: string, + dbadapter: dbadapters.IDbAdapter, + dbclient: dbadapters.IDbClient, + timeoutMillis?: number, + options?: IBigQueryExecutionOptions + ) => Promise; } +export type RunOptionsOrProjectDir = string | IExecutionOptions | dataform.IRunResult; + export function run( dbadapter: dbadapters.IDbAdapter, graph: dataform.IExecutionGraph, - executionOptions?: IExecutionOptions, + executionOptions: RunOptionsOrProjectDir = ".", partiallyExecutedRunResult: dataform.IRunResult = {}, runnerNotificationPeriodMillis: number = flags.runnerNotificationPeriodMillis.get() ): Runner { @@ -49,7 +64,42 @@ export function run( } export class Runner { + private static handleParamsOverloads( + executionOptions: RunOptionsOrProjectDir, + partiallyExecutedRunResult: dataform.IRunResult + ): { options: IExecutionOptions; runResult: dataform.IRunResult } { + if ( + typeof executionOptions === "object" && + executionOptions !== null && + "actions" in executionOptions + ) { + return { + runResult: { + actions: [], + ...(executionOptions as dataform.IRunResult) + }, + options: { projectDir: "." } + }; + } + const options = + typeof executionOptions === "string" + ? { projectDir: executionOptions } + : { ...(executionOptions as IExecutionOptions) }; + if (!options.projectDir) { + options.projectDir = "."; + } + return { + options, + runResult: { + actions: [], + ...partiallyExecutedRunResult + } + }; + } + private readonly warehouseStateByTarget: Map; + private readonly executionSql: ExecutionSql; + private readonly executionOptions: IExecutionOptions; private readonly allActionTargets: Set; private readonly runResult: dataform.IRunResult; @@ -68,17 +118,24 @@ export class Runner { constructor( private readonly dbadapter: dbadapters.IDbAdapter, private readonly graph: dataform.IExecutionGraph, - private readonly executionOptions: IExecutionOptions = {}, + optionsOrProjectDir: RunOptionsOrProjectDir = ".", partiallyExecutedRunResult: dataform.IRunResult = {}, private readonly runnerNotificationPeriodMillis: number = flags.runnerNotificationPeriodMillis.get() ) { + const { options, runResult } = Runner.handleParamsOverloads( + optionsOrProjectDir, + partiallyExecutedRunResult + ); + this.executionOptions = options; + this.runResult = runResult; + + if (!this.executionOptions.jitCompiler) { + this.executionOptions.jitCompiler = JitCompileChildProcess.compile; + } + this.executionSql = new ExecutionSql(graph.projectConfig, version); this.allActionTargets = new Set( graph.actions.map(action => targetStringifier.stringify(action.target)) ); - this.runResult = { - actions: [], - ...partiallyExecutedRunResult - }; this.warehouseStateByTarget = new Map(); graph.warehouseState.tables?.forEach(tableMetadata => this.warehouseStateByTarget.set( @@ -303,13 +360,28 @@ export class Runner { ]); } + private getBigQueryExecutionOptions(action: dataform.IExecutionAction): IBigQueryExecutionOptions { + return { + dryRun: this.executionOptions.bigquery?.dryRun, + jobPrefix: this.executionOptions.bigquery?.jobPrefix, + labels: { + ...(this.executionOptions?.bigquery?.labels || {}), + ...(action.actionDescriptor?.bigqueryLabels || {}) + }, + actionRetryLimit: this.executionOptions.bigquery?.actionRetryLimit, + reservation: + action.actionDescriptor?.reservation || + this.graph.projectConfig?.defaultReservation + }; + } + private async executeAction(action: dataform.IExecutionAction): Promise { let actionResult: dataform.IActionResult = { target: action.target, tasks: [] }; - if (action.tasks.length === 0) { + if ((action.tasks.length === 0 && !action.jitCode) || action.disabled) { actionResult.status = dataform.ActionResult.ExecutionStatus.DISABLED; this.runResult.actions.push(actionResult); this.notifyListeners(); @@ -329,7 +401,14 @@ export class Runner { actionResult.timing = timer.current(); this.notifyListeners(); - await this.dbadapter.withClientLock(async client => { + try { + if (action.jitCode) { + await this.compileJitAction(action, actionResult, this.dbadapter); + if ((actionResult.status as dataform.ActionResult.ExecutionStatus) === dataform.ActionResult.ExecutionStatus.FAILED) { + return actionResult; + } + } + // Start running tasks from the last executed task (if any), onwards. for (const task of action.tasks.slice(actionResult.tasks.length)) { if (this.stopped) { @@ -339,20 +418,8 @@ export class Runner { actionResult.status === dataform.ActionResult.ExecutionStatus.RUNNING && !this.cancelled ) { - const taskStatus = await this.executeTask(client, task, actionResult, { - bigquery: { - // Merge global run-level labels with action-level labels. Action-level labels take precedence. - labels: { - ...(this.executionOptions?.bigquery?.labels || {}), - ...(action.actionDescriptor?.bigqueryLabels || {}) - }, - actionRetryLimit: this.executionOptions?.bigquery?.actionRetryLimit, - jobPrefix: this.executionOptions?.bigquery?.jobPrefix, - dryRun: this.executionOptions?.bigquery?.dryRun, - reservation: - action.actionDescriptor?.reservation || - this.graph.projectConfig?.defaultReservation - } + const taskStatus = await this.executeTask(this.dbadapter, task, actionResult, { + bigquery: this.getBigQueryExecutionOptions(action) }); if (taskStatus === dataform.TaskResult.ExecutionStatus.FAILED) { actionResult.status = dataform.ActionResult.ExecutionStatus.FAILED; @@ -365,7 +432,15 @@ export class Runner { }); } } - }); + } catch (e) { + if ((actionResult.status as dataform.ActionResult.ExecutionStatus) !== dataform.ActionResult.ExecutionStatus.FAILED) { + actionResult.status = dataform.ActionResult.ExecutionStatus.FAILED; + actionResult.tasks.push({ + status: dataform.TaskResult.ExecutionStatus.FAILED, + errorMessage: `Unexpected execution error: ${e.message}` + }); + } + } if (this.stopped) { return actionResult; @@ -423,6 +498,10 @@ export class Runner { }; parentAction.tasks.push(taskResult); this.notifyListeners(); + if (options.bigquery?.dryRun) { + taskResult.compiledSql = task.statement; + + } if (options.bigquery?.dryRun && task.type === "assertion") { taskResult.status = dataform.TaskResult.ExecutionStatus.SUCCESSFUL; } @@ -466,6 +545,76 @@ export class Runner { this.notifyListeners(); return taskResult.status; } + + private async compileJitAction( + action: dataform.IExecutionAction, + actionResult: dataform.IActionResult, + client: dbadapters.IDbClient + ) { + let compilationTargetType = + dataform.JitCompilationTargetType.JIT_COMPILATION_TARGET_TYPE_UNSPECIFIED; + if (action.type === "table") { + compilationTargetType = + action.tableType === "incremental" + ? dataform.JitCompilationTargetType.JIT_COMPILATION_TARGET_TYPE_INCREMENTAL_TABLE + : dataform.JitCompilationTargetType.JIT_COMPILATION_TARGET_TYPE_TABLE; + } else if (action.type === "operation") { + compilationTargetType = dataform.JitCompilationTargetType.JIT_COMPILATION_TARGET_TYPE_OPERATION; + } + + const jitRequest = dataform.JitCompilationRequest.create({ + target: action.target, + dependencies: action.dependencyTargets, + jitCode: action.jitCode, + fileName: action.fileName, + compilationTargetType, + jitData: this.graph.jitData + }); + + const timeoutMillis = this.graph.runConfig?.timeoutMillis || DEFAULT_COMPILATION_TIMEOUT_MILLIS; + + try { + const jitResponse = await this.executionOptions.jitCompiler( + jitRequest, + this.executionOptions.projectDir, + this.dbadapter, + client, + timeoutMillis, + this.getBigQueryExecutionOptions(action) + ); + + if (jitResponse.table) { + const table = dataform.Table.create({ + ...action, + ...jitResponse.table, + enumType: action.tableType === "view" ? dataform.TableType.VIEW : (action.tableType === "incremental" ? dataform.TableType.INCREMENTAL : dataform.TableType.TABLE) + }); + action.tasks = this.executionSql.createTableTasks(table, this.graph.runConfig, this.warehouseStateByTarget.get(targetStringifier.stringify(action.target))); + } else if (jitResponse.operation) { + const operation = dataform.Operation.create({ + ...action, + ...jitResponse.operation + }); + action.tasks = this.executionSql.createOperationTasks(operation); + } else if (jitResponse.incrementalTable) { + const table = dataform.Table.create({ + ...action, + ...jitResponse.incrementalTable.regular, + incrementalQuery: jitResponse.incrementalTable.incremental?.query, + incrementalPreOps: jitResponse.incrementalTable.incremental?.preOps, + incrementalPostOps: jitResponse.incrementalTable.incremental?.postOps, + enumType: dataform.TableType.INCREMENTAL + }); + action.tasks = this.executionSql.createTableTasks(table, this.graph.runConfig, this.warehouseStateByTarget.get(targetStringifier.stringify(action.target))); + } + } catch (e) { + actionResult.status = dataform.ActionResult.ExecutionStatus.FAILED; + actionResult.tasks.push({ + status: dataform.TaskResult.ExecutionStatus.FAILED, + errorMessage: `JiT compilation error: ${e.message}` + }); + } + } } class Timer { diff --git a/cli/api/dbadapters/bigquery.ts b/cli/api/dbadapters/bigquery.ts index 6f02bc609..1450070fc 100644 --- a/cli/api/dbadapters/bigquery.ts +++ b/cli/api/dbadapters/bigquery.ts @@ -3,7 +3,14 @@ import Long from "long"; import { PromisePoolExecutor } from "promise-pool-executor"; import { collectEvaluationQueries, QueryOrAction } from "df/cli/api/dbadapters/execution_sql"; -import { IBigQueryError, IDbAdapter, IDbClient, IExecutionResult, OnCancel } from "df/cli/api/dbadapters/index"; +import { + IBigQueryError, + IDbAdapter, + IDbClient, + IExecutionResult, + IExecutionResultRaw, + OnCancel +} from "df/cli/api/dbadapters/index"; import { parseBigqueryEvalError } from "df/cli/api/utils/error_parsing"; import { LimitedResultSet } from "df/cli/api/utils/results"; import { coerceAsError } from "df/common/errors/errors"; @@ -30,14 +37,55 @@ export interface IBigQueryExecutionOptions { reservation?: string; } +export interface IBigQueryClientProvider { + get(projectId?: string): BigQuery; +} + +export class BigQueryClientProvider implements IBigQueryClientProvider { + private readonly clients = new Map(); + + constructor(private readonly credentials: dataform.IBigQuery) {} + + public get(projectId?: string): BigQuery { + projectId = projectId || this.credentials.projectId; + if (!this.clients.has(projectId)) { + this.clients.set( + projectId, + new BigQuery({ + projectId, + scopes: EXTRA_GOOGLE_SCOPES, + location: this.credentials.location, + credentials: this.credentials.credentials && JSON.parse(this.credentials.credentials) + }) + ); + } + return this.clients.get(projectId); + } +} + +export class StaticBigQueryClientProvider implements IBigQueryClientProvider { + constructor(private readonly client: BigQuery) {} + + public get(projectId?: string): BigQuery { + return this.client; + } +} + export class BigQueryDbAdapter implements IDbAdapter { private bigQueryCredentials: dataform.IBigQuery; private pool: PromisePoolExecutor; + private clientProvider: IBigQueryClientProvider; - private readonly clients = new Map(); - - constructor(credentials: dataform.IBigQuery, options?: { concurrencyLimit: number }) { + constructor( + credentials: dataform.IBigQuery, + options?: { + concurrencyLimit?: number; + clientProvider?: IBigQueryClientProvider; + } + ) { this.bigQueryCredentials = credentials; + this.clientProvider = options?.clientProvider || new BigQueryClientProvider(credentials); + // Bigquery allows 50 concurrent queries, and a rate limit of 100/user/second by default. // These limits should be safely low enough for most projects. this.pool = new PromisePoolExecutor({ @@ -92,8 +140,29 @@ export class BigQueryDbAdapter implements IDbAdapter { .promise(); } - public async withClientLock(callback: (client: IDbClient) => Promise) { - return await callback(this); + public async executeRaw( + statement: string, + options: { + params?: { [name: string]: any }; + rowLimit?: number; + bigquery?: IBigQueryExecutionOptions; + } = { rowLimit: 1000 } + ): Promise { + if (!statement) { + throw new Error("Query string cannot be empty"); + } + return this.pool + .addSingleTask({ + generator: async () => { + const [rows, , apiResponse] = await this.getClient().query({ + ...this.prepareQueryOptions(statement, options.rowLimit, options.bigquery, options.params), + skipParsing: true + } as any); + const schema = apiResponse?.schema?.fields?.map((field: any) => convertField(field)); + return { rows, schema, metadata: {} }; + } + }) + .promise(); } public async evaluate(queryOrAction: QueryOrAction) { @@ -129,22 +198,31 @@ export class BigQueryDbAdapter implements IDbAdapter { ); } - public async tables(): Promise { - const datasets = await this.getClient().getDatasets({ autoPaginate: true, maxResults: 1000 }); - const tables = await Promise.all( - datasets[0].map(dataset => dataset.getTables({ autoPaginate: true, maxResults: 1000 })) - ); - const allTables: dataform.ITarget[] = []; - tables.forEach((tablesResult: GetTablesResponse) => - tablesResult[0].forEach(table => - allTables.push({ - database: table.bigQuery.projectId, - schema: table.dataset.id, - name: table.id - }) - ) + public async tables(database: string, schema?: string): Promise { + const datasetIds = schema ? [schema] : await this.schemas(database); + const tablesMetadata: dataform.ITableMetadata[] = []; + + await Promise.all( + datasetIds.map(async datasetId => { + const [tables] = await this.getClient(database) + .dataset(datasetId) + .getTables({ autoPaginate: true, maxResults: 1000 }); + await Promise.all( + tables.map(async table => { + const metadata = await this.table({ + database, + schema: datasetId, + name: table.id + }); + if (metadata) { + tablesMetadata.push(metadata); + } + }) + ); + }) ); - return allTables; + + return tablesMetadata; } public async search( @@ -218,8 +296,15 @@ export class BigQueryDbAdapter implements IDbAdapter { }); } + public async deleteTable(target: dataform.ITarget): Promise { + await this.getClient(target.database) + .dataset(target.schema) + .table(target.name) + .delete({ ignoreNotFound: true }); + } + public async schemas(database: string): Promise { - const data = await this.getClient(database).getDatasets(); + const data = await this.getClient(database).getDatasets({ autoPaginate: true, maxResults: 1000 }); return data[0].map(dataset => dataset.id); } @@ -267,20 +352,7 @@ export class BigQueryDbAdapter implements IDbAdapter { } private getClient(projectId?: string) { - projectId = projectId || this.bigQueryCredentials.projectId; - if (!this.clients.has(projectId)) { - this.clients.set( - projectId, - new BigQuery({ - projectId, - scopes: EXTRA_GOOGLE_SCOPES, - location: this.bigQueryCredentials.location, - credentials: - this.bigQueryCredentials.credentials && JSON.parse(this.bigQueryCredentials.credentials) - }) - ); - } - return this.clients.get(projectId); + return this.clientProvider.get(projectId); } private async runQuery( @@ -289,7 +361,7 @@ export class BigQueryDbAdapter implements IDbAdapter { rowLimit?: number, byteLimit?: number, location?: string - ) { + ): Promise { const results = await new Promise((resolve, reject) => { const allRows = new LimitedResultSet({ rowLimit, @@ -314,6 +386,25 @@ export class BigQueryDbAdapter implements IDbAdapter { return { rows: cleanRows(results), metadata: {} }; } + private prepareQueryOptions( + query: string, + rowLimit?: number, + bigqueryOptions?: IBigQueryExecutionOptions, + params?: { [name: string]: any } + ) { + return { + query, + useLegacySql: false, + jobPrefix: "dataform-" + (bigqueryOptions?.jobPrefix ? `${bigqueryOptions.jobPrefix}-` : ""), + location: bigqueryOptions?.location, + maxResults: rowLimit, + labels: bigqueryOptions?.labels, + dryRun: bigqueryOptions?.dryRun, + reservation: bigqueryOptions?.reservation, + params + }; + } + private async createQueryJob( query: string, params?: { [name: string]: any }, @@ -325,23 +416,27 @@ export class BigQueryDbAdapter implements IDbAdapter { jobPrefix?: string, dryRun?: boolean, reservation?: string - ) { + ): Promise { let isCancelled = false; onCancel?.(() => (isCancelled = true)); return retry( async () => { try { - const job = await this.getClient().createQueryJob({ - useLegacySql: false, - jobPrefix: "dataform-" + (jobPrefix ? `${jobPrefix}-` : ""), - query, - params, - labels, - location, - dryRun, - reservation - } as any); + const job = await this.getClient().createQueryJob( + this.prepareQueryOptions( + query, + rowLimit, + { + labels, + location, + jobPrefix, + dryRun, + reservation + }, + params + ) as any + ); const resultStream = job[0].getQueryResultsStream(); return new Promise((resolve, reject) => { if (isCancelled) { @@ -465,11 +560,14 @@ function convertFieldType(type: string) { case "INT64": return dataform.Field.Primitive.INTEGER; case "NUMERIC": + case "BIGNUMERIC": return dataform.Field.Primitive.NUMERIC; case "BOOL": case "BOOLEAN": return dataform.Field.Primitive.BOOLEAN; case "STRING": + case "JSON": + case "INTERVAL": return dataform.Field.Primitive.STRING; case "DATE": return dataform.Field.Primitive.DATE; @@ -492,6 +590,9 @@ function addDescriptionToMetadata( columnDescriptions: dataform.IColumnDescriptor[], metadataArray: TableField[] ): TableField[] { + if (!columnDescriptions) { + return metadataArray; + } const findDescription = (path: string[]) => columnDescriptions.find(column => column.path.join("") === path.join("")); diff --git a/cli/api/dbadapters/bigquery_test.ts b/cli/api/dbadapters/bigquery_test.ts new file mode 100644 index 000000000..37012824b --- /dev/null +++ b/cli/api/dbadapters/bigquery_test.ts @@ -0,0 +1,145 @@ +import { Dataset, Table } from "@google-cloud/bigquery"; +import { expect } from "chai"; +import { anything, instance, mock, verify, when } from "ts-mockito"; + +import { BigQueryDbAdapter, StaticBigQueryClientProvider } from "df/cli/api/dbadapters/bigquery"; +import { dataform } from "df/protos/ts"; +import { suite, test } from "df/testing"; + +suite("BigQueryDbAdapter", () => { + test("tables() with schema filters correctly", async () => { + const mockBigQuery = mock(); + const mockDataset = mock(); + const mockTable = mock(); + + const tableName = "table1"; + const schemaName = "schema1"; + const projectId = "project1"; + + const credentials = dataform.BigQuery.create({ projectId, location: "US" }); + const adapter = new BigQueryDbAdapter(credentials, { clientProvider: new StaticBigQueryClientProvider(instance(mockBigQuery)) }); + + when(mockBigQuery.dataset(schemaName)).thenReturn(instance(mockDataset)); + // getTables returns an array where the first element is an array of tables. + // Each table object needs an 'id' property. + when(mockDataset.getTables(anything())).thenReturn(Promise.resolve([[{ id: tableName }]] as any)); + when(mockDataset.table(tableName)).thenReturn(instance(mockTable)); + when(mockTable.getMetadata()).thenReturn( + Promise.resolve([ + { + type: "TABLE", + tableReference: { projectId, datasetId: schemaName, tableId: tableName }, + schema: { fields: [{ name: "col1", type: "STRING", mode: "NULLABLE" }] }, + lastModifiedTime: "123456789" + } + ] as any) + ); + + const result = await adapter.tables(projectId, schemaName); + + expect(result.length).to.equal(1); + expect(result[0].target.database).to.equal(projectId); + expect(result[0].target.schema).to.equal(schemaName); + expect(result[0].target.name).to.equal(tableName); + expect(result[0].fields.length).to.equal(1); + expect(result[0].fields[0].name).to.equal("col1"); + }); + + test("tables() without schema lists all datasets and tables", async () => { + const mockBigQuery = mock(); + const mockDataset = mock(); + const mockTable = mock
(); + const schemaName = "schema1"; + const tableName = "table1"; + const projectId = "project"; + + const credentials = dataform.BigQuery.create({ projectId, location: "US" }); + const adapter = new BigQueryDbAdapter(credentials, { clientProvider: new StaticBigQueryClientProvider(instance(mockBigQuery)) }); + + when(mockBigQuery.dataset(schemaName)).thenReturn(instance(mockDataset)); + when(mockDataset.getTables(anything())).thenReturn(Promise.resolve([[{ id: tableName }]] as any)); + when(mockDataset.table(tableName)).thenReturn(instance(mockTable)); + when(mockTable.getMetadata()).thenReturn( + Promise.resolve([ + { + type: "TABLE", + tableReference: { projectId, datasetId: schemaName, tableId: tableName }, + schema: { fields: [{ name: "col1", type: "STRING" }] }, + lastModifiedTime: "123456789" + } + ] as any) + ); + + when(mockBigQuery.getDatasets(anything())).thenReturn(Promise.resolve([[{ id: schemaName }]] as any)); + + const result = await adapter.tables(projectId); + + expect(result.length).to.equal(1); + expect(result[0].target.database).to.equal(projectId); + expect(result[0].target.schema).to.equal(schemaName); + expect(result[0].target.name).to.equal(tableName); + }); + + test("setMetadata handles action without columns", async () => { + // Partial mock for BigQuery client to avoid real network calls + const mockBigQuery: any = { + dataset: () => ({ + table: () => ({ + getMetadata: () => Promise.resolve([{ schema: { fields: [] } }]), + setMetadata: (metadata: any) => { + expect(metadata.description).to.equal("test"); + return Promise.resolve([]); + } + }) + }) + }; + + const credentials = dataform.BigQuery.create({ projectId: "p", location: "US" }); + const adapter = new BigQueryDbAdapter(credentials, { + concurrencyLimit: 1, + clientProvider: { get: () => mockBigQuery } + }); + + const action = dataform.ExecutionAction.create({ + target: { database: "db", schema: "sch", name: "tab" }, + actionDescriptor: { description: "test" } + // columns is missing/null in this action + }); + + // This should not throw "cannot read property 'find' of undefined" + await adapter.setMetadata(action); + }); + + test("setMetadata correctly maps column descriptions", async () => { + const mockBigQuery: any = { + dataset: () => ({ + table: () => ({ + getMetadata: () => Promise.resolve([{ + schema: { + fields: [{ name: "id", type: "INTEGER" }] + } + }]), + setMetadata: (metadata: any) => { + expect(metadata.schema[0].description).to.equal("id desc"); + return Promise.resolve([]); + } + }) + }) + }; + + const credentials = dataform.BigQuery.create({ projectId: "p", location: "US" }); + const adapter = new BigQueryDbAdapter(credentials, { + concurrencyLimit: 1, + clientProvider: { get: () => mockBigQuery } + }); + + const action = dataform.ExecutionAction.create({ + target: { database: "db", schema: "sch", name: "tab" }, + actionDescriptor: { + columns: [{ path: ["id"], description: "id desc" }] + } + }); + + await adapter.setMetadata(action); + }); +}); diff --git a/cli/api/dbadapters/execution_sql.ts b/cli/api/dbadapters/execution_sql.ts index f226b9e91..e6655514b 100644 --- a/cli/api/dbadapters/execution_sql.ts +++ b/cli/api/dbadapters/execution_sql.ts @@ -168,9 +168,29 @@ from (${query}) as insertions`; return tasks.concatenate(); } + public createTableTasks( + table: dataform.ITable, + runConfig: dataform.IRunConfig, + tableMetadata?: dataform.ITableMetadata + ): dataform.IExecutionTask[] { + return table.disabled ? [] : this.publishTasks(table, runConfig, tableMetadata).build(); + } + + public createOperationTasks(operation: dataform.IOperation): dataform.IExecutionTask[] { + return operation.disabled + ? [] + : operation.queries.map(statement => + dataform.ExecutionTask.create({ type: "statement", statement }) + ); + } + + public createAssertionTasks(assertion: dataform.IAssertion): dataform.IExecutionTask[] { + return assertion.disabled ? [] : this.assertTasks(assertion, this.project).build(); + } + public assertTasks( assertion: dataform.IAssertion, - projectConfig: dataform.IProjectConfig, + projectConfig: dataform.IProjectConfig ): Tasks { const tasks = new Tasks(); const target = assertion.target; diff --git a/cli/api/dbadapters/index.ts b/cli/api/dbadapters/index.ts index 1fe69dff7..ff55d1fff 100644 --- a/cli/api/dbadapters/index.ts +++ b/cli/api/dbadapters/index.ts @@ -8,6 +8,10 @@ export interface IExecutionResult { metadata: dataform.IExecutionMetadata; } +export interface IExecutionResultRaw extends IExecutionResult { + schema?: dataform.IField[]; +} + export interface IBigQueryError extends Error { metadata?: dataform.IExecutionMetadata } @@ -25,23 +29,37 @@ export interface IDbClient { location?: string; jobPrefix?: string; dryRun?: boolean; + reservation?: string; }; } ): Promise; + + executeRaw( + statement: string, + options?: { + params?: { [name: string]: any }; + rowLimit?: number; + bigquery?: { + labels?: { [label: string]: string }; + location?: string; + jobPrefix?: string; + dryRun?: boolean; + reservation?: string; + }; + } + ): Promise; } export interface IDbAdapter extends IDbClient { - withClientLock(callback: (client: IDbClient) => Promise): Promise; - evaluate(queryOrAction: QueryOrAction): Promise; schemas(database: string): Promise; createSchema(database: string, schema: string): Promise; - // TODO: This should take parameters to allow for retrieving from a specific database/schema. - tables(): Promise; + tables(database: string, schema?: string): Promise; search(searchText: string, options?: { limit: number }): Promise; table(target: dataform.ITarget): Promise; + deleteTable(target: dataform.ITarget): Promise; setMetadata(action: dataform.IExecutionAction): Promise; } diff --git a/cli/api/utils/constants.ts b/cli/api/utils/constants.ts new file mode 100644 index 000000000..21907e9a7 --- /dev/null +++ b/cli/api/utils/constants.ts @@ -0,0 +1 @@ +export const DEFAULT_COMPILATION_TIMEOUT_MILLIS = 300000; diff --git a/cli/console.ts b/cli/console.ts index 63ab55185..864c5f7d9 100644 --- a/cli/console.ts +++ b/cli/console.ts @@ -42,6 +42,22 @@ const writeStdErr = (text: string, indentCount: number = 0) => const DEFAULT_PROMPT = "> "; +export class Logger { + constructor(private readonly shouldLog: boolean) {} + + public log(text: string) { + if (this.shouldLog) { + print(text); + } + } + + public success(text: string) { + if (this.shouldLog) { + printSuccess(text); + } + } +} + export function question(questionText: string) { return prompt(questionText); } @@ -512,9 +528,14 @@ function printExecutedActionErrors( task => task.status === dataform.TaskResult.ExecutionStatus.FAILED ); failingTasks.forEach((task, i) => { - executionAction.tasks[i].statement.split("\n").forEach(line => { - writeStdErr(`${DEFAULT_PROMPT}${line}`, 1); - }); + // For JiT actions, the original executionAction.tasks might be empty + // since they are generated during re-compilation. + const statement = task.compiledSql || executionAction.tasks[i]?.statement; + if (statement) { + statement.split("\n").forEach((line: string) => { + writeStdErr(`${DEFAULT_PROMPT}${line}`, 1); + }); + } printError(task.errorMessage, 1); }); } diff --git a/cli/index.ts b/cli/index.ts index 0ba3e7f0f..bc8d10315 100644 --- a/cli/index.ts +++ b/cli/index.ts @@ -10,6 +10,7 @@ import { CREDENTIALS_FILENAME } from "df/cli/api/commands/credentials"; import { BigQueryDbAdapter } from "df/cli/api/dbadapters/bigquery"; import { prettyJsonStringify } from "df/cli/api/utils"; import { + Logger, print, printCompiledGraph, printCompiledGraphErrors, @@ -393,11 +394,10 @@ export function runCli() { ], processFn: async argv => { const projectDir = argv[projectDirMustExistOption.name]; + const logger = new Logger(!argv[jsonOutputOption.name]); async function compileAndPrint() { - if (!argv[jsonOutputOption.name]) { - print("Compiling...\n"); - } + logger.log("Compiling...\n"); const compiledGraph = await compile({ projectDir, projectConfigOverride: ProjectConfigOptions.constructProjectConfigOverride(argv), @@ -541,7 +541,6 @@ export function runCli() { fullRefreshOption, includeDepsOption, includeDependentsOption, - credentialsOption, jsonOutputOption, timeoutOption, tagsOption, @@ -549,16 +548,17 @@ export function runCli() { ...ProjectConfigOptions.allYargsOptions ], processFn: async argv => { - if (argv[jsonOutputOption.name] && !argv[dryRunOptionName]) { - print( + const isJsonOutput = argv[jsonOutputOption.name]; + const logger = new Logger(!isJsonOutput); + + if (isJsonOutput && !argv[dryRunOptionName]) { + printError( `For execution, the --${jsonOutputOption.name} option is only supported if the ` + `--${dryRunOptionName} option is enabled` ); return; } - if (!argv[jsonOutputOption.name]) { - print("Compiling...\n"); - } + logger.log("Compiling...\n"); const compiledGraph = await compile({ projectDir: argv[projectDirOption.name], projectConfigOverride: ProjectConfigOptions.constructProjectConfigOverride(argv), @@ -568,9 +568,7 @@ export function runCli() { printCompiledGraphErrors(compiledGraph.graphErrors, argv[quietCompileOption.name]); return 1; } - if (!argv[jsonOutputOption.name]) { - printSuccess("Compiled successfully.\n"); - } + logger.success("Compiled successfully.\n"); const readCredentials = credentials.read( getCredentialsPath(argv[projectDirOption.name], argv[credentialsOption.name]) ); @@ -583,25 +581,30 @@ export function runCli() { actions: argv[actionsOption.name], includeDependencies: argv[includeDepsOption.name], includeDependents: argv[includeDependentsOption.name], - tags: argv[tagsOption.name] + tags: argv[tagsOption.name], + timeoutMillis: argv[timeoutOption.name] || undefined }, dbadapter ); - if (argv[dryRunOptionName] && argv[jsonOutputOption.name]) { - printExecutionGraph(executionGraph, argv[jsonOutputOption.name]); + if ( + argv[dryRunOptionName] && + isJsonOutput && + !executionGraph.actions.some(action => !!action.jitCode) + ) { + printExecutionGraph(executionGraph, isJsonOutput); return; } if (argv[runTestsOptionName]) { - print(`Running ${compiledGraph.tests.length} unit tests...\n`); + logger.log(`Running ${compiledGraph.tests.length} unit tests...\n`); const testResults = await test(dbadapter, compiledGraph.tests); testResults.forEach(testResult => printTestResult(testResult)); if (testResults.some(testResult => !testResult.successful)) { printError("\nUnit tests did not pass; aborting run."); return 1; } - printSuccess("Unit tests completed successfully.\n"); + logger.success("Unit tests completed successfully.\n"); } let bigqueryOptions: {} = { @@ -623,17 +626,21 @@ export function runCli() { }); if (actionsByName.size === 0) { - print("No actions to run.\n"); + logger.log("No actions to run.\n"); return 0; } if (argv[dryRunOptionName]) { - print("Dry running (no changes to the warehouse will be applied)..."); + logger.log("Dry running (no changes to the warehouse will be applied)..."); } else { - print("Running...\n"); + logger.log("Running...\n"); } - const runner = run(dbadapter, executionGraph, { bigquery: bigqueryOptions }); + const runner = run( + dbadapter, + executionGraph, + { projectDir: argv[projectDirOption.name], bigquery: bigqueryOptions } + ); process.on("SIGINT", () => { runner.cancel(); }); @@ -660,9 +667,16 @@ export function runCli() { }); }; - runner.onChange(printExecutedGraph); + if (!isJsonOutput) { + runner.onChange(printExecutedGraph); + } const runResult = await runner.result(); - printExecutedGraph(runResult); + if (!isJsonOutput) { + printExecutedGraph(runResult); + } + if (isJsonOutput) { + print(prettyJsonStringify(runResult)); + } return runResult.status === dataform.RunResult.ExecutionStatus.SUCCESSFUL ? 0 : 1; } }, diff --git a/cli/index_run_e2e_test.ts b/cli/index_run_e2e_test.ts index ba675834a..42ee3e2b0 100644 --- a/cli/index_run_e2e_test.ts +++ b/cli/index_run_e2e_test.ts @@ -169,8 +169,10 @@ select 1 as \${dataform.projectConfig.vars.testVar2} } ], type: "table", + disabled: false } ], + jitData: {}, projectConfig: { assertionSchema: "dataform_assertions", defaultDatabase: DEFAULT_DATABASE, @@ -286,6 +288,7 @@ SELECT 1 as id } ], type: "table", + disabled: false }, { fileName: "definitions/test_assertion.sqlx", @@ -296,8 +299,10 @@ SELECT 1 as id schema: "dataform_assertions" }, type: "assertion", + disabled: true } ], + jitData: {}, projectConfig: { assertionSchema: "dataform_assertions", defaultDatabase: DEFAULT_DATABASE, diff --git a/cli/index_test_base.ts b/cli/index_test_base.ts index b47f53fe9..45b48a1d0 100644 --- a/cli/index_test_base.ts +++ b/cli/index_test_base.ts @@ -1,9 +1,62 @@ // tslint:disable tsr-detect-non-literal-fs-filename +import { execFile } from "child_process"; +import * as fs from "fs-extra"; +import { dump as dumpYaml, load as loadYaml } from "js-yaml"; import * as path from "path"; +import { version } from "df/core/version"; +import { dataform } from "df/protos/ts"; +import { corePackageTarPath, getProcessResult, nodePath, npmPath } from "df/testing"; +import { TmpDirFixture } from "df/testing/fixtures"; + export const DEFAULT_DATABASE = "dataform-open-source"; export const DEFAULT_LOCATION = "US"; export const DEFAULT_RESERVATION = "projects/dataform-open-source/locations/us/reservations/dataform-test"; export const CREDENTIALS_PATH = path.resolve(process.env.RUNFILES, "df/test_credentials/bigquery.json"); export const cliEntryPointPath = "cli/node_modules/@dataform/cli/bundle.js"; + +export async function setupJitProject( + tmpDirFixture: TmpDirFixture, + projectDir: string +): Promise { + const npmCacheDir = tmpDirFixture.createNewTmpDir(); + const packageJsonPath = path.join(projectDir, "package.json"); + + await getProcessResult( + execFile(nodePath, [cliEntryPointPath, "init", projectDir, DEFAULT_DATABASE, DEFAULT_LOCATION]) + ); + + const workflowSettingsPath = path.join(projectDir, "workflow_settings.yaml"); + const workflowSettings = dataform.WorkflowSettings.create( + loadYaml(fs.readFileSync(workflowSettingsPath, "utf8")) + ); + delete workflowSettings.dataformCoreVersion; + fs.writeFileSync(workflowSettingsPath, dumpYaml(workflowSettings)); + + fs.writeFileSync( + packageJsonPath, + `{ + "dependencies":{ + "@dataform/core": "${version}" + } +}` + ); + await getProcessResult( + execFile(npmPath, [ + "install", + "--prefix", + projectDir, + "--cache", + npmCacheDir, + corePackageTarPath + ]) + ); + + const jitTablePath = path.join(projectDir, "definitions", "jit_table.js"); + fs.ensureFileSync(jitTablePath); + fs.writeFileSync( + jitTablePath, + `publish("jit_table", {type: "table"}).jitCode(async (ctx) => { return "SELECT 1 as id"; })` + ); +} diff --git a/cli/tests/jit/index_jit_advanced_test.ts b/cli/tests/jit/index_jit_advanced_test.ts new file mode 100644 index 000000000..b32e4a4e5 --- /dev/null +++ b/cli/tests/jit/index_jit_advanced_test.ts @@ -0,0 +1,249 @@ +import { expect } from "chai"; +import { execFile } from "child_process"; +import * as fs from "fs-extra"; +import * as path from "path"; + +import { + cliEntryPointPath, + CREDENTIALS_PATH, + setupJitProject +} from "df/cli/index_test_base"; +import { getProcessResult, nodePath, suite, test } from "df/testing"; +import { TmpDirFixture } from "df/testing/fixtures"; + +suite("JiT support advanced", ({ afterEach }) => { + const tmpDirFixture = new TmpDirFixture(afterEach); + + test("JiT preOps and postOps support", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const prePostPath = path.join(projectDir, "definitions", "pre_post_jit.js"); + fs.writeFileSync( + prePostPath, + `publish("pre_post_jit", { type: "table" }).jitCode(async (jctx) => { + return { + query: "SELECT 1 as id", + preOps: ["SELECT 'pre' as p"], + postOps: ["SELECT 'post' as p"] + }; + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=pre_post_jit" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + const prePostAction = executedGraph.actions.find((a: any) => a.target.name === "pre_post_jit"); + const statement = prePostAction.tasks[0].compiledSql; + expect(statement).to.include("SELECT 'pre' as p"); + expect(statement).to.include("SELECT 1 as id"); + expect(statement).to.include("SELECT 'post' as p"); + }); + + test("JiT incremental pre/post ops support", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const incPrePostPath = path.join(projectDir, "definitions", "inc_pre_post_jit.js"); + fs.writeFileSync( + incPrePostPath, + `publish("inc_pre_post_jit", { type: "incremental" }).jitCode(async (jctx) => { + if (jctx.incremental()) { + return { + query: "SELECT 'inc_path_query' as q", + preOps: ["SELECT 'inc_path_pre' as p"] + }; + } else { + return { + query: "SELECT 'reg_path_query' as q", + preOps: ["SELECT 'reg_path_pre' as p"] + }; + } + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=inc_pre_post_jit", + "--full-refresh" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + const incAction = executedGraph.actions.find((a: any) => a.target.name === "inc_pre_post_jit"); + const statement = incAction.tasks[0].compiledSql; + expect(statement).to.include("SELECT 'reg_path_pre' as p"); + expect(statement).to.include("SELECT 'reg_path_query' as q"); + + // Also validate when not using full-refresh. + // Since the table doesn't exist, jctx.incremental() should still be false. + const runResultIncremental = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=inc_pre_post_jit" + ]) + ); + + expect(runResultIncremental.exitCode).equals(0); + const executedGraphInc = JSON.parse(runResultIncremental.stdout); + const incActionInc = executedGraphInc.actions.find((a: any) => a.target.name === "inc_pre_post_jit"); + const statementInc = incActionInc.tasks[0].compiledSql; + expect(statementInc).to.include("SELECT 'reg_path_pre' as p"); + expect(statementInc).to.include("SELECT 'reg_path_query' as q"); + }); + + test("JiT incremental mode validation with consecutive runs", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const incPath = path.join(projectDir, "definitions", "inc_jit.js"); + fs.writeFileSync( + incPath, + `publish("inc_jit", { type: "incremental" }).jitCode(async (jctx) => { + if (jctx.incremental()) { + return { + query: "SELECT 'inc_query' as q", + preOps: ["SELECT 'inc_pre' as p"] + }; + } else { + return { + query: "SELECT 'reg_query' as q", + preOps: ["SELECT 'reg_pre' as p"] + }; + } + })` + ); + + // 1. Initial run with full-refresh to create the table. + const firstRun = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--actions=inc_jit", + "--full-refresh" + ]) + ); + expect(firstRun.exitCode).equals(0); + + // 2. Second run without full-refresh. + // The table now exists, so it should use the incremental path. + const secondRun = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=inc_jit" + ]) + ); + + expect(secondRun.exitCode).equals(0); + const secondGraph = JSON.parse(secondRun.stdout); + const secondAction = secondGraph.actions.find((a: any) => a.target.name === "inc_jit"); + // Assert second run is INCREMENTAL + expect(secondAction.tasks[0].compiledSql).to.include("SELECT 'inc_pre' as p"); + expect(secondAction.tasks[0].compiledSql).to.include("SELECT 'inc_query' as q"); + }); + + test("JiT project-level data support", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + fs.writeFileSync( + path.join(projectDir, "definitions", "project_data.js"), + "const { session } = require('@dataform/core');\nsession.jitData('app_secret', 'e2e_secret_value');" + ); + fs.writeFileSync( + path.join(projectDir, "definitions", "jit_data_test.js"), + `publish("jit_data_test", { type: "table" }).jitCode(async (jctx) => { + const secret = jctx.data.app_secret; + return "SELECT '" + secret + "' as val"; + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=jit_data_test" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + const dataAction = executedGraph.actions.find((a: any) => a.target.name === "jit_data_test"); + expect(dataAction.tasks[0].compiledSql).to.include("e2e_secret_value"); + }); + + test("JiT complex session data support", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + fs.writeFileSync( + path.join(projectDir, "definitions", "complex_project_data.js"), + "const { session } = require('@dataform/core');\n" + + "session.jitData('app_config', {\n" + + " env: 'test-env',\n" + + " version: 1.2,\n" + + " tags: ['t1', 't2']\n" + + "});" + ); + fs.writeFileSync( + path.join(projectDir, "definitions", "jit_complex_data_test.js"), + "publish('jit_complex_data_test', { type: 'table' }).jitCode(async (jctx) => {\n" + + " const config = jctx.data.app_config;\n" + + " return 'SELECT \\'' + config.env + '\\' as env, ' + config.version + ' as ver, \\'' + config.tags[0] + '\\' as tag';\n" + + "})" + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=jit_complex_data_test" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + const dataAction = executedGraph.actions.find((a: any) => a.target.name === "jit_complex_data_test"); + expect(dataAction.tasks[0].compiledSql).to.include("SELECT 'test-env' as env, 1.2 as ver, 't1' as tag"); + }); +}); diff --git a/cli/tests/jit/index_jit_dependency_test.ts b/cli/tests/jit/index_jit_dependency_test.ts new file mode 100644 index 000000000..b3f1393d3 --- /dev/null +++ b/cli/tests/jit/index_jit_dependency_test.ts @@ -0,0 +1,93 @@ +import { expect } from "chai"; +import { execFile } from "child_process"; +import * as fs from "fs-extra"; +import * as path from "path"; + +import { + cliEntryPointPath, + CREDENTIALS_PATH, + setupJitProject +} from "df/cli/index_test_base"; +import { getProcessResult, nodePath, suite, test } from "df/testing"; +import { TmpDirFixture } from "df/testing/fixtures"; + +suite("JiT support dependencies", ({ afterEach }) => { + const tmpDirFixture = new TmpDirFixture(afterEach); + + test("JiT transitive dependency pruning", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + // A (AoT) -> B (JiT) + fs.writeFileSync( + path.join(projectDir, "definitions", "table_a.sqlx"), + "config { type: 'table' } SELECT 1 as val" + ); + fs.writeFileSync( + path.join(projectDir, "definitions", "table_b.js"), + `publish("table_b", { type: "table", dependencies: ["table_a"] }).jitCode(async (jctx) => { + const upstream = jctx.ref("table_a"); + return "SELECT '" + upstream + "' as ref_name"; + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=table_b", + "--include-deps" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + // Should have BOTH tables because of --include-deps + expect(executedGraph.actions.length).to.equal(2); + expect(executedGraph.actions.some((a: any) => a.target.name === "table_a")).to.equal(true); + const actionB = executedGraph.actions.find((a: any) => a.target.name === "table_b"); + expect(actionB).to.not.equal(undefined); + expect(actionB.tasks[0].compiledSql).to.include("SELECT '`dataform-open-source.dataform.table_a`' as ref_name"); + }); + + test("JiT to JiT dependency chain", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + // Action A (JiT) -> Action B (JiT) + fs.writeFileSync( + path.join(projectDir, "definitions", "jit_a.js"), + 'publish("jit_a", { type: "table" }).jitCode(async () => "SELECT 1 as val")' + ); + fs.writeFileSync( + path.join(projectDir, "definitions", "jit_b.js"), + "publish('jit_b', { type: 'table', dependencies: ['jit_a'] }).jitCode(async (jctx) => {\n" + + " const upstream = jctx.ref('jit_a');\n" + + " return 'SELECT \\'' + upstream + '\\' as ref_name';\n" + + "})" + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=jit_b", + "--include-deps" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + expect(executedGraph.actions.length).to.equal(2); + const actionB = executedGraph.actions.find((a: any) => a.target.name === "jit_b"); + expect(actionB.tasks[0].compiledSql).to.include("SELECT '`dataform-open-source.dataform.jit_a`' as ref_name"); + }); +}); diff --git a/cli/tests/jit/index_jit_main_test.ts b/cli/tests/jit/index_jit_main_test.ts new file mode 100644 index 000000000..015581b38 --- /dev/null +++ b/cli/tests/jit/index_jit_main_test.ts @@ -0,0 +1,274 @@ +import { expect } from "chai"; +import { execFile } from "child_process"; +import * as fs from "fs-extra"; +import * as path from "path"; + +import { + cliEntryPointPath, + CREDENTIALS_PATH, + DEFAULT_DATABASE, + setupJitProject +} from "df/cli/index_test_base"; +import { getProcessResult, nodePath, suite, test } from "df/testing"; +import { TmpDirFixture } from "df/testing/fixtures"; + +suite("JiT support main", ({ afterEach }) => { + const tmpDirFixture = new TmpDirFixture(afterEach); + + test("compile command includes jitCode in output", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const compileResult = await getProcessResult( + execFile(nodePath, [cliEntryPointPath, "compile", projectDir, "--json"]) + ); + + expect(compileResult.exitCode).equals(0); + const compiledGraph = JSON.parse(compileResult.stdout); + const jitTable = compiledGraph.tables.find((t: any) => t.target.name === "jit_table"); + expect(!!jitTable).to.equal(true); + expect(jitTable.type).to.equal("table"); + expect(jitTable.jitCode).to.contain("async (ctx) => { return \"SELECT 1 as id\"; }"); + expect(compiledGraph).to.have.property("jitData"); + }); + + test("fails if both query and jitCode are provided", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const conflictPath = path.join(projectDir, "definitions", "conflict.js"); + fs.writeFileSync( + conflictPath, + `publish("conflict", {type: "table"}).query("SELECT 1").jitCode(async (ctx) => "SELECT 2")` + ); + + const compileResult = await getProcessResult( + execFile(nodePath, [cliEntryPointPath, "compile", projectDir, "--json"]) + ); + + expect(compileResult.exitCode).equals(1); + expect(compileResult.stderr).to.include("Cannot mix AoT and JiT compilation in action"); + }); + + test("run command performs JiT compilation during execution", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=jit_table" + ]) + ); + + expect(runResult.exitCode).equals(0); + + const executedGraph = JSON.parse(runResult.stdout); + const jitAction = executedGraph.actions.find((a: any) => a.target.name === "jit_table"); + expect(!!jitAction).to.equal(true); + // Tasks array should be populated by the JiT runner + expect(jitAction.tasks.length).to.be.greaterThan(0); + expect(jitAction.tasks[0].compiledSql).to.include("SELECT 1 as id"); + }); + + test("mixed AoT and JiT support", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const aotTablePath = path.join(projectDir, "definitions", "aot_table.sqlx"); + fs.writeFileSync(aotTablePath, "config { type: 'table' } SELECT 2 as id"); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json" + ]) + ); + + expect(runResult.exitCode).equals(0); + + const executedGraph = JSON.parse(runResult.stdout); + const aotAction = executedGraph.actions.find((a: any) => a.target.name === "aot_table"); + const jitAction = executedGraph.actions.find((a: any) => a.target.name === "jit_table"); + + expect(!!aotAction).to.equal(true); + expect(!!jitAction).to.equal(true); + expect(executedGraph.actions.length).to.equal(2); + + expect(aotAction.tasks[0].compiledSql).to.include("SELECT 2 as id"); + // JiT action should have its tasks populated dynamically + expect(jitAction.tasks.length).to.be.greaterThan(0); + expect(jitAction.tasks[0].compiledSql).to.include("SELECT 1 as id"); + }); + + test("JiT respects disabled flag", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + + const disabledPath = path.join(projectDir, "definitions", "disabled_jit.js"); + fs.writeFileSync( + disabledPath, + `publish("disabled_jit", { type: "table", disabled: true }).jitCode(async (jctx) => { + throw new Error("Should not be executed"); + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--actions=disabled_jit"], + { + env: { ...process.env, NO_COLOR: "1" } + } + ) + ); + + expect(runResult.exitCode).equals(0); + // When an action is disabled, it should print a "disabled" message. + expect(runResult.stdout).to.include("Dataset creation disabled: dataform.disabled_jit [table] [disabled]"); + }); + + test("JiT compilation failure reporting", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const failingJitPath = path.join(projectDir, "definitions", "failing_jit.js"); + fs.writeFileSync( + failingJitPath, + `publish("failing_jit", {type: "table"}).jitCode(async (ctx) => { throw new Error("JiT compilation failed!"); })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=failing_jit" + ]) + ); + + expect(runResult.exitCode).equals(1); + + const executedGraph = JSON.parse(runResult.stdout); + const failingAction = executedGraph.actions.find((a: any) => a.target.name === "failing_jit"); + + expect(!!failingAction).to.equal(true); + expect(failingAction.status).to.equal(3); // FAILED + expect(failingAction.tasks[0].status).to.equal(3); // FAILED + expect(failingAction.tasks[0].errorMessage).to.include("JiT compilation failed!"); + }); + + test("surfaces 'Table not found' RPC error during JiT compilation", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const rpcJitPath = path.join(projectDir, "definitions", "rpc_jit.js"); + fs.writeFileSync( + rpcJitPath, + `publish("rpc_jit", {type: "table"}).jitCode(async (jctx) => { + // This will fail because the table does not exist in the warehouse, + // and jctx.adapter.getTable throws an error in this case. + const table = await jctx.adapter.getTable({target: {database: "${DEFAULT_DATABASE}", schema: "sch", name: "tab"}}); + return "SELECT 1 as id"; + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=rpc_jit" + ]) + ); + + expect(runResult.exitCode).equals(1); + + const executedGraph = JSON.parse(runResult.stdout); + const rpcAction = executedGraph.actions.find((a: any) => a.target.name === "rpc_jit"); + + expect(!!rpcAction).to.equal(true); + expect(rpcAction.status).to.equal(3); + expect(rpcAction.tasks[0].status).to.equal(3); + expect(rpcAction.tasks[0].errorMessage).to.include("JiT compilation error"); + expect(rpcAction.tasks[0].errorMessage).to.include("Table not found"); + expect(rpcAction.tasks[0].errorMessage).to.include(DEFAULT_DATABASE); + expect(rpcAction.tasks[0].errorMessage).to.include('"schema":"sch"'); + expect(rpcAction.tasks[0].errorMessage).to.include('"name":"tab"'); + }); + + test("mixed support with AoT filtered out", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const aotTablePath = path.join(projectDir, "definitions", "aot_table.sqlx"); + fs.writeFileSync(aotTablePath, "config { type: 'table' } SELECT 2 as id"); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=jit_table" + ]) + ); + + expect(runResult.exitCode).equals(0); + + const executedGraph = JSON.parse(runResult.stdout); + expect(executedGraph.actions.length).to.equal(1); + const jitAction = executedGraph.actions.find((a: any) => a.target.name === "jit_table"); + expect(!!jitAction).to.equal(true); + expect(jitAction.tasks.length).to.be.greaterThan(0); + expect(jitAction.tasks[0].compiledSql).to.include("SELECT 1 as id"); + }); + + test("mixed support with JiT filtered out", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + const aotTablePath = path.join(projectDir, "definitions", "aot_table.sqlx"); + fs.writeFileSync(aotTablePath, "config { type: 'table' } SELECT 2 as id"); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=aot_table" + ]) + ); + + expect(runResult.exitCode).equals(0); + + const executedGraph = JSON.parse(runResult.stdout); + expect(executedGraph.actions.length).to.equal(1); + const aotAction = executedGraph.actions.find((a: any) => a.target.name === "aot_table"); + expect(!!aotAction).to.equal(true); + expect(aotAction.tasks[0].statement).to.include("SELECT 2 as id"); + }); +}); diff --git a/cli/tests/jit/index_jit_runtime_test.ts b/cli/tests/jit/index_jit_runtime_test.ts new file mode 100644 index 000000000..0af1eae01 --- /dev/null +++ b/cli/tests/jit/index_jit_runtime_test.ts @@ -0,0 +1,145 @@ +import { expect } from "chai"; +import { execFile } from "child_process"; +import * as fs from "fs-extra"; +import * as path from "path"; + +import { + cliEntryPointPath, + CREDENTIALS_PATH, + setupJitProject +} from "df/cli/index_test_base"; +import { getProcessResult, nodePath, suite, test } from "df/testing"; +import { TmpDirFixture } from "df/testing/fixtures"; + +suite("JiT support runtime", ({ afterEach }) => { + const tmpDirFixture = new TmpDirFixture(afterEach); + + test("JiT require() of local files support", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + // Add a helper JS file + fs.ensureDirSync(path.join(projectDir, "helpers")); + fs.writeFileSync( + path.join(projectDir, "helpers", "utils.js"), + "module.exports = { getValue: () => 'required_value' };" + ); + // Add a JiT table that requires it + fs.writeFileSync( + path.join(projectDir, "definitions", "jit_require_test.js"), + `publish("jit_require_test", { type: "table" }).jitCode(async (jctx) => { + const utils = require("../helpers/utils.js"); + return "SELECT '" + utils.getValue() + "' as val"; + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=jit_require_test" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + const reqAction = executedGraph.actions.find((a: any) => a.target.name === "jit_require_test"); + expect(reqAction.tasks[0].compiledSql).to.include("required_value"); + }); + + test("JiT worker timeout handling", { timeout: 15000 }, async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + + // Add a JiT table that hangs in an infinite loop + const hangPath = path.join(projectDir, "definitions", "hang_jit.js"); + fs.writeFileSync( + hangPath, + `publish("hang_jit", { type: "table" }).jitCode(async (jctx) => { + while(true) { /* loop */ } + return "SELECT 1"; + })` + ); + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=hang_jit", + "--timeout=4s" + ], { timeout: 20000 }) + ); + + expect(runResult.exitCode).equals(1); + expect(runResult.stdout).to.include("Worker timed out"); + }); + + test("JiT parallel execution robustness", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + // Add multiple JiT tables + for (let i = 0; i < 5; i++) { + fs.writeFileSync( + path.join(projectDir, "definitions", `jit_${i}.js`), + `publish("jit_${i}", { type: "table" }).jitCode(async (jctx) => "SELECT ${i} as val")` + ); + } + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json" + ]) + ); + + expect(runResult.exitCode).equals(0); + const executedGraph = JSON.parse(runResult.stdout); + expect(executedGraph.actions.filter((a: any) => a.target.name.startsWith("jit_")).length).to.equal(6); // jit_table + 5 others + }); + + test("JiT handles hard worker crash", async () => { + const projectDir = tmpDirFixture.createNewTmpDir(); + await setupJitProject(tmpDirFixture, projectDir); + // Add a JiT table that crashes the process + const crashPath = path.join(projectDir, "definitions", "crash_jit.js"); + fs.writeFileSync( + crashPath, + `publish("crash_jit", { type: "table" }).jitCode(async (jctx) => { + setTimeout(() => { throw new Error("Hard crash"); }, 10); + return new Promise(() => {}); // Hang until crash + })` + ); + + const runResult = await getProcessResult( + execFile(nodePath, [ + cliEntryPointPath, + "run", + projectDir, + "--credentials", + CREDENTIALS_PATH, + "--dry-run", + "--json", + "--actions=crash_jit" + ]) + ); + + expect(runResult.exitCode).equals(1); + const executedGraph = JSON.parse(runResult.stdout); + const crashAction = executedGraph.actions.find((a: any) => a.target.name === "crash_jit"); + expect(crashAction.status).to.equal(3); // FAILED + expect(crashAction.tasks[0].errorMessage).to.include("Worker exited with code 1"); + }); +}); diff --git a/cli/tests/jit/jit_build_test.ts b/cli/tests/jit/jit_build_test.ts new file mode 100644 index 000000000..9798eed4f --- /dev/null +++ b/cli/tests/jit/jit_build_test.ts @@ -0,0 +1,40 @@ +import { expect } from "chai"; + +import { Builder } from "df/cli/api/commands/build"; +import { dataform } from "df/protos/ts"; +import { suite, test } from "df/testing"; + +suite("build", () => { + test("jit_code is preserved in ExecutionAction", () => { + const compiledGraph = dataform.CompiledGraph.create({ + projectConfig: { warehouse: "bigquery" }, + tables: [ + { + target: { database: "db", schema: "schema", name: "table" }, + jitCode: "console.log('jit table')", + enumType: dataform.TableType.TABLE + } + ], + operations: [ + { + target: { database: "db", schema: "schema", name: "operation" }, + jitCode: "console.log('jit operation')", + queries: [] + } + ] + }); + + const builder = new Builder(compiledGraph, {}, { tables: [] }); + const executionGraph = builder.build(); + + const tableAction = executionGraph.actions.find( + (a: dataform.IExecutionAction) => a.target.name === "table" + ); + expect(tableAction.jitCode).equals("console.log('jit table')"); + + const operationAction = executionGraph.actions.find( + (a: dataform.IExecutionAction) => a.target.name === "operation" + ); + expect(operationAction.jitCode).equals("console.log('jit operation')"); + }); +}); diff --git a/cli/tests/jit/jit_run_test.ts b/cli/tests/jit/jit_run_test.ts new file mode 100644 index 000000000..1656dd6c8 --- /dev/null +++ b/cli/tests/jit/jit_run_test.ts @@ -0,0 +1,416 @@ +import { expect } from "chai"; +import { anything, capture, instance, mock, verify, when } from "ts-mockito"; + +import { handleDbRequest as handleRpc } from "df/cli/api/commands/jit/rpc"; +import { Runner } from "df/cli/api/commands/run"; +import { IDbAdapter, IDbClient } from "df/cli/api/dbadapters"; +import { jitCompile } from "df/core/jit_compiler"; +import { dataform } from "df/protos/ts"; +import { suite, test } from "df/testing"; + +suite("run", () => { + test("JiT compilation is performed for Table actions", async () => { + const { mockAdapter, adapterInstance } = createMocks(); + + const executionGraph = createGraph([ + { + target: { database: "db", schema: "sch", name: "jit_table" }, + type: "table", + tableType: "table", + jitCode: "async (jctx) => { return 'SELECT 1'; }", + tasks: [] + } + ]); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + // RPC callback bridge for tests + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const result = await runner.execute().result(); + + // Verify overall run status + if (result.status !== dataform.RunResult.ExecutionStatus.SUCCESSFUL) { + process.stderr.write("Run failed with actions: " + JSON.stringify(result.actions, null, 2) + "\n"); + } + expect(result.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + + // Verify action result + const actionResult = result.actions[0]; + expect(actionResult.target.name).equals("jit_table"); + expect(actionResult.status).equals(dataform.ActionResult.ExecutionStatus.SUCCESSFUL); + + // Verify task results + expect(actionResult.tasks.length).equals(1); + expect(actionResult.tasks[0].status).equals(dataform.TaskResult.ExecutionStatus.SUCCESSFUL); + + // Verify that the Runner executed the query statement returned by JiT compilation + verify(mockAdapter.execute(anything(), anything())).atLeast(1); + }); + + test("JiT compilation is performed for Operation actions", async () => { + const { mockAdapter, adapterInstance } = createMocks(); + + const executionGraph = createGraph([ + { + target: { database: "db", schema: "sch", name: "jit_op" }, + type: "operation", + jitCode: "async (jctx) => { return ['SELECT 1', 'SELECT 2']; }", + tasks: [] + } + ]); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + // RPC callback bridge for tests + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const result = await runner.execute().result(); + + expect(result.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + + const actionResult = result.actions[0]; + expect(actionResult.status).equals(dataform.ActionResult.ExecutionStatus.SUCCESSFUL); + expect(actionResult.tasks.length).equals(2); + expect(actionResult.tasks[0].status).equals(dataform.TaskResult.ExecutionStatus.SUCCESSFUL); + expect(actionResult.tasks[1].status).equals(dataform.TaskResult.ExecutionStatus.SUCCESSFUL); + + verify(mockAdapter.execute("SELECT 1", anything())).once(); + verify(mockAdapter.execute("SELECT 2", anything())).once(); + }); + + test("Mixed run with JiT and AoT actions", async () => { + const { mockAdapter, adapterInstance } = createMocks(); + + const executionGraph = createGraph([ + { + target: { database: "db", schema: "sch", name: "aot_table" }, + type: "table", + tableType: "table", + tasks: [dataform.ExecutionTask.create({ statement: "SELECT 'aot'", type: "statement" })] + }, + { + target: { database: "db", schema: "sch", name: "jit_table" }, + type: "table", + tableType: "table", + jitCode: "async (jctx) => { return 'SELECT \"jit\"'; }", + tasks: [], + dependencyTargets: [{ database: "db", schema: "sch", name: "aot_table" }] + } + ]); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + // RPC callback bridge for tests + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const result = await runner.execute().result(); + + expect(result.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + expect(result.actions.length).equals(2); + + const aotResult = result.actions.find((a: dataform.IActionResult) => a.target.name === "aot_table"); + const jitResult = result.actions.find((a: dataform.IActionResult) => a.target.name === "jit_table"); + + expect(aotResult.status).equals(dataform.ActionResult.ExecutionStatus.SUCCESSFUL); + expect(aotResult.tasks.length).equals(1); + expect(aotResult.tasks[0].status).equals(dataform.TaskResult.ExecutionStatus.SUCCESSFUL); + + expect(jitResult.status).equals(dataform.ActionResult.ExecutionStatus.SUCCESSFUL); + expect(jitResult.tasks.length).equals(1); + expect(jitResult.tasks[0].status).equals(dataform.TaskResult.ExecutionStatus.SUCCESSFUL); + + // Verify that both actions resulted in database execution calls + verify(mockAdapter.execute(anything(), anything())).atLeast(2); + const [firstStatement] = capture(mockAdapter.execute).first(); + const [secondStatement] = capture(mockAdapter.execute).second(); + const allStatements = [firstStatement, secondStatement]; + expect(allStatements.some((s: string) => s.includes("SELECT 'aot'"))).to.equal(true); + expect(allStatements.some((s: string) => s.includes("SELECT \"jit\""))).to.equal(true); + }); + + test("Handles JiT compilation syntax error", async () => { + const { adapterInstance } = createMocks(); + + const executionGraph = createGraph([ + { + target: { database: "db", schema: "sch", name: "bad_jit" }, + type: "table", + tableType: "table", + jitCode: "async (jctx) => { return syntax error; }", + tasks: [] + } + ]); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + // RPC callback bridge for tests + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const result = await runner.execute().result(); + + expect(result.status).equals(dataform.RunResult.ExecutionStatus.FAILED); + + const actionResult = result.actions[0]; + expect(actionResult.status).equals(dataform.ActionResult.ExecutionStatus.FAILED); + expect(actionResult.tasks.length).equals(1); + expect(actionResult.tasks[0].status).equals(dataform.TaskResult.ExecutionStatus.FAILED); + expect(actionResult.tasks[0].errorMessage).to.contain("JiT compilation error"); + }); + + test("Handles database error during JiT compilation (RPC failure)", async () => { + const { adapterInstance } = createMocks(); + + const executionGraph = createGraph([ + { + target: { database: "db", schema: "sch", name: "jit_db_error" }, + type: "table", + tableType: "table", + // This code calls jctx.adapter.execute() which triggers our mockClient.execute + jitCode: "async (jctx) => { await jctx.adapter.execute({statement: 'SELECT fail'}); return 'SELECT 2'; }", + tasks: [] + } + ]); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + // RPC callback bridge for tests + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + + const result = await runner.execute().result(); + + expect(result.status).equals(dataform.RunResult.ExecutionStatus.FAILED); + + const actionResult = result.actions[0]; + expect(actionResult.status).equals(dataform.ActionResult.ExecutionStatus.FAILED); + expect(actionResult.tasks.length).equals(1); + expect(actionResult.tasks[0].status).equals(dataform.TaskResult.ExecutionStatus.FAILED); + expect(actionResult.tasks[0].errorMessage).to.contain("RPC DB Fail"); + }); + + test("Handles JiT incremental table compilation", async () => { + const target = { database: "db", schema: "sch", name: "incremental_jit" }; + const executionGraph = createGraph([ + { + target, + type: "table", + tableType: "incremental", + jitCode: `async (jctx) => { + return jctx.incremental() ? "SELECT 'inc' as t" : "SELECT 'full' as t"; + }`, + tasks: [] + } + ]); + + let runner: Runner; + + // 1. First run - empty warehouse, should use 'full' path + const { mockAdapter: mockAdapterFull, adapterInstance: adapterInstanceFull } = createMocks(); + runner = new Runner(adapterInstanceFull, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const fullResult = await runner.execute().result(); + expect(fullResult.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + + verify(mockAdapterFull.execute(anything(), anything())).atLeast(1); + const [executedSqlFull] = capture(mockAdapterFull.execute).last(); + expect(executedSqlFull).to.contain("create or replace table `db.sch.incremental_jit` as"); + expect(executedSqlFull).to.contain("SELECT 'full' as t"); + + // 2. Mock that the table now exists in the warehouse + executionGraph.warehouseState.tables.push({ + target, + type: dataform.TableMetadata.Type.TABLE, + fields: [{ name: "t" }] + }); + + // 3. Second run - table exists, should use 'incremental' path + const { + mockAdapter: mockAdapterIncremental, + adapterInstance: adapterInstanceIncremental + } = createMocks(); + runner = new Runner(adapterInstanceIncremental, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const incrementalResult = await runner.execute().result(); + expect(incrementalResult.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + + verify(mockAdapterIncremental.execute(anything(), anything())).atLeast(1); + const [executedSqlIncremental] = capture(mockAdapterIncremental.execute).last(); + expect(executedSqlIncremental).to.contain("SELECT 'inc' as t"); + }); + + test("Handles JiT incremental table compilation - incremental mode", async () => { + const { mockAdapter, adapterInstance } = createMocks(); + + const target = { database: "db", schema: "sch", name: "incremental_jit" }; + const executionGraph = createGraph([ + { + target, + type: "table", + tableType: "incremental", + jitCode: `async (jctx) => { + return jctx.incremental() ? "SELECT 'inc' as t" : "SELECT 'full' as t"; + }`, + tasks: [] + } + ]); + // Mock that the table already exists in the warehouse as a TABLE with a 't' field + executionGraph.warehouseState.tables.push({ + target, + type: dataform.TableMetadata.Type.TABLE, + fields: [{ name: "t" }] + }); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const result = await runner.execute().result(); + + expect(result.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + + // Verify it used the 'incremental' query path + verify(mockAdapter.execute(anything(), anything())).atLeast(1); + const [executedSql] = capture(mockAdapter.execute).last(); + // For BigQuery, it should be an 'insert into' because no uniqueKey was specified. + // We check for substrings without trailing spaces to avoid exact whitespace mismatches. + // tslint:disable: tsr-detect-sql-literal-injection + expect(executedSql).to.equal( + "insert into `db.sch.incremental_jit` \n" + + "(`t`) \n" + + "select `t` \n" + + "from (SELECT 'inc' as t) as insertions" + ); + // tslint:enable: tsr-detect-sql-literal-injection + }); + + test("JiT compilation with RPC calls (ListTables, GetTable, DeleteTable)", async () => { + const { mockAdapter, adapterInstance } = createMocks(); + + const target = { database: "db", schema: "sch", name: "existing_table" }; + when(mockAdapter.tables(anything(), anything())).thenResolve([{ target }]); + when(mockAdapter.table(anything())).thenResolve({ + target, + type: dataform.TableMetadata.Type.TABLE + } as any); + + const executionGraph = createGraph([ + { + target: { database: "db", schema: "sch", name: "jit_rpc_test" }, + type: "table", + tableType: "table", + jitCode: `async (jctx) => { + const list = await jctx.adapter.listTables({ database: "db", schema: "sch" }); + const table = await jctx.adapter.getTable({ target: list.tables[0].target }); + await jctx.adapter.deleteTable({ target: table.target }); + return "SELECT '" + table.target.name + "' as deleted_table"; + }`, + tasks: [] + } + ]); + + const runner = new Runner(adapterInstance, executionGraph, { + jitCompiler: async (req, pdir, adapter) => { + return await jitCompile(req, (method, internalReq, callback) => { + (adapter as any).rpcImpl(method, internalReq, callback); + }); + } + }); + const result = await runner.execute().result(); + + expect(result.status).equals(dataform.RunResult.ExecutionStatus.SUCCESSFUL); + const actionResult = result.actions[0]; + expect(actionResult.status).equals(dataform.ActionResult.ExecutionStatus.SUCCESSFUL); + + verify(mockAdapter.deleteTable(anything())).once(); + const [deletedTarget] = capture(mockAdapter.deleteTable).last(); + expect(deletedTarget.name).equals("existing_table"); + + verify(mockAdapter.execute(anything(), anything())).once(); + const [executedSql] = capture(mockAdapter.execute).last(); + expect(executedSql).to.contain("SELECT 'existing_table' as deleted_table"); + }); +}); + +function createMocks() { + const mockAdapter = mock(); + const mockClient = mock(); + + when(mockAdapter.schemas(anything())).thenResolve([]); + when(mockAdapter.execute(anything(), anything())).thenCall((statement: string) => { + if (statement.includes("fail") || statement.includes("nonexistent")) { + throw new Error("RPC DB Fail"); + } + return Promise.resolve({ + rows: [], + metadata: {} + }); + }); + when(mockClient.executeRaw(anything(), anything())).thenCall((statement: string) => { + if (statement.includes("fail") || statement.includes("nonexistent")) { + throw new Error("RPC DB Fail"); + } + return Promise.resolve({ + rows: [], + metadata: {} + }); + }); + when(mockClient.execute(anything(), anything())).thenCall((statement: string) => { + if (statement.includes("fail") || statement.includes("nonexistent")) { + throw new Error("RPC DB Fail"); + } + return Promise.resolve({ + rows: [], + metadata: {} + }); + }); + + const adapterInstance = instance(mockAdapter); + (adapterInstance as any).rpcImpl = (method: string, req: Uint8Array, callback: any) => { + handleRpc(instance(mockAdapter), instance(mockClient), method, req) + .then((res: Uint8Array) => callback(null, res)) + .catch((err: Error) => callback(err, null)); + }; + + return { mockAdapter, mockClient, adapterInstance }; +} + +function createGraph(actions: any[]): dataform.ExecutionGraph { + return dataform.ExecutionGraph.create({ + projectConfig: { warehouse: "bigquery" }, + runConfig: { fullRefresh: false, timeoutMillis: 30000 }, + warehouseState: { tables: [] }, + actions: actions.map(a => ({ + dependencyTargets: [], + ...a + })) + }); +} diff --git a/cli/vm/BUILD b/cli/vm/BUILD index 69d30e808..4adc5f391 100644 --- a/cli/vm/BUILD +++ b/cli/vm/BUILD @@ -4,7 +4,10 @@ load("//tools:ts_library.bzl", "ts_library") ts_library( name = "vm", - srcs = ["compile.ts"], + srcs = [ + "compile.ts", + "jit_worker.ts", + ], deps = [ "//common/protos", "//core", @@ -23,6 +26,7 @@ ts_library( srcs = [], data = [ ":compile_loader.js", + ":jit_loader.js", ], deps = [ ":vm", @@ -50,3 +54,16 @@ nodejs_binary( "--bazel_patch_module_resolver", ], ) + +nodejs_binary( + name = "jit_worker_bin", + data = [ + ":vm", + "@npm//source-map-support", + ], + entry_point = ":jit_worker.ts", + templated_args = [ + "--node_options=--require=source-map-support/register", + "--bazel_patch_module_resolver", + ], +) diff --git a/cli/vm/compile.ts b/cli/vm/compile.ts index f5e8cfc1e..526198373 100644 --- a/cli/vm/compile.ts +++ b/cli/vm/compile.ts @@ -82,6 +82,9 @@ export function listenForCompileRequest() { } if (require.main === module) { + if (process.send) { + process.send({ type: "worker_booted" }); + } listenForCompileRequest(); } diff --git a/cli/vm/jit_loader.js b/cli/vm/jit_loader.js new file mode 100644 index 000000000..cb325cee9 --- /dev/null +++ b/cli/vm/jit_loader.js @@ -0,0 +1,12 @@ +'use strict'; + +if (require.main === module) { + var entryPointPath = 'df/cli/vm/jit_worker.js'; + var mainScript = process.argv[1] = entryPointPath; + try { + module.constructor._load(mainScript, this, /*isMain=*/true); + } catch (e) { + console.error(e.stack || e); + process.exit(1); + } +} diff --git a/cli/vm/jit_worker.ts b/cli/vm/jit_worker.ts new file mode 100644 index 000000000..1d3bca0ef --- /dev/null +++ b/cli/vm/jit_worker.ts @@ -0,0 +1,132 @@ +import * as fs from "fs"; +import * as path from "path"; +import { NodeVM } from "vm2"; + +import { dataform } from "df/protos/ts"; + +const pendingRpcCallbacks = new Map void>(); + +// Guard against double-initialization in some environments (e.g. Bazel) +const globalObj = global as any; +if (!globalObj._dataform_jit_worker_initialized) { + globalObj._dataform_jit_worker_initialized = true; + globalObj._has_started_processing = false; + + process.on("message", (res: any) => { + if (res.type === "rpc_response") { + const callback = pendingRpcCallbacks.get(res.correlationId); + if (callback) { + pendingRpcCallbacks.delete(res.correlationId); + callback(res.error || null, res.response ? Buffer.from(res.response) : null); + } + } + }); + + if (require.main === module || globalObj._dataform_jit_worker_force_main) { + if (process.send) { + process.send({ type: "worker_booted" }); + } + process.on("message", async (message: any) => { + if (message.type === "jit_compile") { + if (globalObj._has_started_processing) { + process.send({ + type: "jit_error", + error: "Worker process received multiple JiT compilation requests. Subsequent requests are rejected." + }); + return; + } + globalObj._has_started_processing = true; + await handleJitRequest(message); + } + }); + } +} + +export async function handleJitRequest(message: { + request: any; + projectDir: string; +}) { + try { + const { request, projectDir } = message; + + const projectLocalCorePath = path.join(projectDir, "node_modules", "@dataform", "core", "bundle.js"); + const hasProjectLocalCore = fs.existsSync(projectLocalCorePath); + + if (!hasProjectLocalCore && !fs.existsSync(path.join(projectDir, "node_modules", "@dataform", "core", "package.json"))) { + throw new Error( + "Could not find a recent installed version of @dataform/core in the project. Check that " + + "either `dataformCoreVersion` is specified in `workflow_settings.yaml`, or " + + "`@dataform/core` is specified in `package.json`. If using `package.json`, then run " + + "`dataform install`." + ); + } + + const rpcCallback = (method: string, reqBytes: Uint8Array, callback: (err: string | null, resBytes: Uint8Array | null) => void) => { + const correlationId = Math.random().toString(36).substring(7); + pendingRpcCallbacks.set(correlationId, callback); + + process.send({ + type: "rpc_request", + method, + request: reqBytes, + correlationId + }); + }; + + const requestMessage = dataform.JitCompilationRequest.fromObject(request); + const requestBytes = Array.from(dataform.JitCompilationRequest.encode(requestMessage).finish()); + + // Use the action's file name as the VM filename for correct relative requires. + const vmFileName = requestMessage.fileName + ? path.resolve(projectDir, requestMessage.fileName) + : path.resolve(projectDir, "index.js"); + + const vm = new NodeVM({ + env: process.env, + require: { + builtin: ["path", "fs"], + context: "sandbox", + external: true, + root: projectDir, + mock: hasProjectLocalCore ? {} : { + "@dataform/core": require("@dataform/core") + }, + resolve: (moduleName, parentDirName) => { + if (moduleName.startsWith(".")) { + return path.resolve(parentDirName, moduleName); + } + return moduleName; + } + }, + sourceExtensions: ["js", "json"] + }); + + const jitCompileInVm = vm.run(` + const { jitCompiler } = require("@dataform/core"); + + global.require = require; + + module.exports = async (requestBytes, armoredRpcCallback) => { + const requestBytesTyped = new Uint8Array(requestBytes); + const internalRpcCallback = (method, reqBytes, callback) => { + armoredRpcCallback(method, Array.from(reqBytes), (errStr, resBytes) => { + if (errStr) { + return callback(new Error(errStr)); + } + callback(null, resBytes ? new Uint8Array(resBytes) : new Uint8Array()); + }); + }; + + const compilerInstance = jitCompiler(internalRpcCallback); + return await compilerInstance.compile(requestBytesTyped); + }; + `, vmFileName); + + const responseBytes = await jitCompileInVm(requestBytes, rpcCallback); + const response = dataform.JitCompilationResponse.decode(new Uint8Array(responseBytes as number[])); + + process.send({ type: "jit_response", response: response.toJSON() }); + } catch (e) { + process.send({ type: "jit_error", error: e.stack || e.message }); + } +} diff --git a/common/protos/BUILD b/common/protos/BUILD index 87efbbe1c..6f12d80ed 100644 --- a/common/protos/BUILD +++ b/common/protos/BUILD @@ -10,6 +10,7 @@ ts_library( deps = [ "//:modules-fix", "//common/strings", + "//protos:ts", "@npm//protobufjs", ], ) diff --git a/common/protos/structs.ts b/common/protos/structs.ts new file mode 100644 index 000000000..b937ed517 --- /dev/null +++ b/common/protos/structs.ts @@ -0,0 +1,70 @@ +import { google } from "df/protos/ts"; + +export class Structs { + public static toObject(struct?: google.protobuf.IStruct): { [key: string]: any } | undefined { + if (!struct || !struct.fields) { + return undefined; + } + const result: { [key: string]: any } = {}; + for (const [key, value] of Object.entries(struct.fields)) { + result[key] = this.fromValue(value); + } + return result; + } + + public static fromObject(obj: { [key: string]: any }): google.protobuf.IStruct { + const fields: { [key: string]: google.protobuf.IValue } = {}; + for (const [key, val] of Object.entries(obj)) { + fields[key] = this.toValue(val); + } + return { fields }; + } + + private static fromValue(value: google.protobuf.IValue): any { + if (value.nullValue !== null && value.nullValue !== undefined) { + return null; + } + if (value.numberValue !== null && value.numberValue !== undefined) { + return value.numberValue; + } + if (value.stringValue !== null && value.stringValue !== undefined) { + return value.stringValue; + } + if (value.boolValue !== null && value.boolValue !== undefined) { + return value.boolValue; + } + if (value.structValue !== null && value.structValue !== undefined) { + return this.toObject(value.structValue); + } + if (value.listValue !== null && value.listValue !== undefined) { + return (value.listValue.values || []).map((v: any) => this.fromValue(v)); + } + return undefined; + } + + private static toValue(val: any): google.protobuf.IValue { + if (typeof val === "number") { + return { numberValue: val }; + } + if (typeof val === "string") { + return { stringValue: val }; + } + if (typeof val === "boolean") { + return { boolValue: val }; + } + if (val === null || val === undefined) { + return { nullValue: 0 }; + } + if (Array.isArray(val)) { + return { + listValue: { + values: val.map(v => this.toValue(v)) + } + }; + } + if (typeof val === "object") { + return { structValue: this.fromObject(val) }; + } + return { nullValue: 0 }; + } +} diff --git a/core/jit_context.ts b/core/jit_context.ts index d7e994836..ed0432bde 100644 --- a/core/jit_context.ts +++ b/core/jit_context.ts @@ -1,3 +1,4 @@ +import { Structs } from "df/common/protos/structs"; import { IActionContext, ITableContext, JitContext, Resolvable } from "df/core/contextables"; import { ambiguousActionNameMsg, resolvableAsTarget, ResolvableMap, stringifyResolvable, toResolvable } from "df/core/utils"; import { dataform, google } from "df/protos/ts"; @@ -24,7 +25,7 @@ export class SqlActionJitContext implements JitContext { actionTarget: dep, value: canonicalTargetValue(dep) }))); - this.data = jitDataToJsValue(request.jitData); + this.data = Structs.toObject(request.jitData); } public self(): string { @@ -103,46 +104,3 @@ export class IncrementalTableJitContext extends TableJitContext { } } -function jitDataToJsValue(value?: google.protobuf.IStruct): { [key: string]: {} } | undefined { - if (value === undefined || value === null) { - return - } - function protobufValueToJs(val: google.protobuf.IValue): {} { - if (val.nullValue != null) { - return null; - } - if (val.stringValue != null) { - return val.stringValue; - } - if (val.numberValue != null) { - return val.numberValue; - } - if (val.boolValue != null) { - return val.boolValue; - } - if (val.listValue != null) { - return (val.listValue.values || []).map(protobufValueToJs); - } - if (val.structValue != null) { - return Object.fromEntries( - Object.entries(val.structValue.fields || {}).map( - ([fieldKey, fieldValue]) => ([ - fieldKey, - protobufValueToJs(fieldValue) - ]) - ) - ); - } - - throw new Error(`Unsupported protobuf value: ${JSON.stringify(val)}`); - } - - return Object.fromEntries( - Object.entries(value.fields || {}).map( - ([fieldKey, fieldValue]) => [ - fieldKey, - protobufValueToJs(fieldValue) - ] - ) - ); -} diff --git a/packages/@dataform/cli/BUILD b/packages/@dataform/cli/BUILD index 7a1d7085c..cb5f9f4b2 100644 --- a/packages/@dataform/cli/BUILD +++ b/packages/@dataform/cli/BUILD @@ -10,7 +10,10 @@ ts_library( srcs = glob(["*.ts"]), deps = [ "//cli", + "//cli/api", + "//cli/vm", "//cli/vm:compile_loader", + "//protos:ts", ], ) diff --git a/packages/@dataform/cli/worker.ts b/packages/@dataform/cli/worker.ts index bc5f29306..4f279061f 100644 --- a/packages/@dataform/cli/worker.ts +++ b/packages/@dataform/cli/worker.ts @@ -1,2 +1,25 @@ -import { listenForCompileRequest } from "df/cli/vm/compile"; -listenForCompileRequest(); +import { compile } from "df/cli/vm/compile"; +import { handleJitRequest } from "df/cli/vm/jit_worker"; + +process.on("message", async (message: any) => { + try { + if (message.type === "jit_compile") { + await handleJitRequest(message); + } else { + // It's an AoT compile request + const responseBase64 = compile(message); + process.send(responseBase64); + } + } catch (e) { + process.send({ + error: e.message || String(e), + stack: e.stack, + name: e.name + }); + } +}); + +// Signal that the worker is alive and listening +if (process.send) { + process.send({ type: "worker_booted" }); +} diff --git a/protos/execution.proto b/protos/execution.proto index 36725e823..8379cefed 100644 --- a/protos/execution.proto +++ b/protos/execution.proto @@ -46,6 +46,8 @@ message ExecutionAction { string jit_code = 12; + bool disabled = 13; + reserved 1, 3, 7; } @@ -118,6 +120,7 @@ message TaskResult { string error_message = 2; Timing timing = 3; ExecutionMetadata metadata = 4; + string compiled_sql = 5; } message TestResult { diff --git a/protos/jit.proto b/protos/jit.proto index 999a98e92..7adeb5b59 100644 --- a/protos/jit.proto +++ b/protos/jit.proto @@ -34,6 +34,8 @@ message ExecuteRequest { bool fetch_results = 4; BigQueryExecuteOptions big_query_options = 5; + // Query parameters for parameterized queries. + google.protobuf.Struct params = 6; } message BigQueryExecuteOptions { @@ -49,6 +51,8 @@ message BigQueryExecuteOptions { string job_prefix = 5; // Is dry run job. bool dry_run = 6; + // BigQuery reservation to use for the job. + string reservation = 7; } // Synchronous execution response result. @@ -109,6 +113,8 @@ message JitCompilationRequest { JitCompilationTargetType compilation_target_type = 5; // List of additional file paths accessible at compilation. repeated string file_paths = 6; + // File name where the target is defined. + string file_name = 7; } // JiT compilation response. diff --git a/testing/hook.ts b/testing/hook.ts index cc342ffe9..d132bb007 100644 --- a/testing/hook.ts +++ b/testing/hook.ts @@ -20,7 +20,7 @@ export function hook( export type IHookHandler = typeof Hook.create; export class Hook { - public static readonly DEFAULT_TIMEOUT_MILLIS = 30000; + public static readonly DEFAULT_TIMEOUT_MILLIS = 300000; public static create( nameOrOptions: IHookOptions | string, diff --git a/testing/test.ts b/testing/test.ts index 6d3ea5a4a..78ba52ece 100644 --- a/testing/test.ts +++ b/testing/test.ts @@ -26,7 +26,7 @@ export function test( } export class Test { - public static readonly DEFAULT_TIMEOUT_MILLIS = 30000; + public static readonly DEFAULT_TIMEOUT_MILLIS = 300000; public static create( nameOrOptions: ITestOptions | string, diff --git a/tests/api/api.spec.ts b/tests/api/api.spec.ts index b29936140..d42b3ae60 100644 --- a/tests/api/api.spec.ts +++ b/tests/api/api.spec.ts @@ -973,8 +973,6 @@ suite("@dataform/api", () => { ).thenReject(new Error("bad statement")); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const runner = new Runner(mockDbAdapterInstance, RUN_TEST_GRAPH); @@ -1018,8 +1016,6 @@ suite("@dataform/api", () => { ).thenReject(new Error("bad statement")); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); let runner = new Runner(mockDbAdapterInstance, RUN_TEST_GRAPH); runner.execute(); @@ -1041,7 +1037,7 @@ suite("@dataform/api", () => { }).toJSON() ); - runner = new Runner(mockDbAdapterInstance, RUN_TEST_GRAPH, undefined, result); + runner = new Runner(mockDbAdapterInstance, RUN_TEST_GRAPH, result); expect( dataform.RunResult.create(cleanTiming(await runner.execute().result())).toJSON() @@ -1079,8 +1075,6 @@ suite("@dataform/api", () => { .thenResolve({ rows: [], metadata: {} }); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const runner = new Runner(mockDbAdapterInstance, NEW_TEST_GRAPH, { bigquery: { actionRetryLimit: 1 } @@ -1119,8 +1113,6 @@ suite("@dataform/api", () => { .thenResolve({ rows: [], metadata: {} }); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const runner = new Runner(mockDbAdapterInstance, NEW_TEST_GRAPH, { bigquery: { actionRetryLimit: 2 } @@ -1183,8 +1175,6 @@ suite("@dataform/api", () => { .thenResolve({ rows: [], metadata: {} }); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const runner = new Runner(mockDbAdapterInstance, NEW_TEST_GRAPH_WITH_OPERATION, { bigquery: { actionRetryLimit: 3 } @@ -1235,7 +1225,6 @@ suite("@dataform/api", () => { reject(new Error("Run cancelled")); }); }), - withClientLock: callback => callback(mockDbAdapter), schemas: _ => Promise.resolve([]), createSchema: (_, __) => Promise.resolve(), table: _ => undefined @@ -1292,8 +1281,6 @@ suite("@dataform/api", () => { }); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const labels = { env: "testing", team: "dataform" }; const runner = new Runner(mockDbAdapterInstance, NEW_TEST_GRAPH, { @@ -1355,8 +1342,6 @@ suite("@dataform/api", () => { }); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const globalLabels = { env: "testing", team: "dataform" }; const runner = new Runner(mockDbAdapterInstance, NEW_TEST_GRAPH, { @@ -1434,8 +1419,6 @@ suite("@dataform/api", () => { ); const mockDbAdapterInstance = instance(mockedDbAdapter); - mockDbAdapterInstance.withClientLock = async callback => - await callback(mockDbAdapterInstance); const runner = new Runner(mockDbAdapterInstance, METADATA_TEST_GRAPH); diff --git a/tests/api/projects.spec.ts b/tests/api/projects.spec.ts index 82c62f802..f70bc85ab 100644 --- a/tests/api/projects.spec.ts +++ b/tests/api/projects.spec.ts @@ -694,10 +694,13 @@ suite("examples", () => { test("times out after timeout period during compilation", async () => { try { - await compile({ projectDir: "tests/api/projects/never_finishes_compiling" }); + await compile({ + projectDir: "tests/api/projects/never_finishes_compiling", + timeoutMillis: 1000 + }); fail("Compilation timeout Error expected."); } catch (e) { - expect(e.message).to.equal("Compilation timed out"); + expect(e.message).to.equal("Worker timed out after 1 seconds"); } }); diff --git a/tests/integration/BUILD b/tests/integration/BUILD index 5f1ee674c..fdccfff82 100644 --- a/tests/integration/BUILD +++ b/tests/integration/BUILD @@ -12,6 +12,8 @@ ts_test_suite( "//test_credentials:bigquery.json", "//tests/integration/bigquery_project:files", "//tests/integration/bigquery_project:node_modules", + "//packages/@dataform/core:bundle.js", + "//packages/@dataform/core:package.json", ], tags = ["integration"], deps = [ @@ -22,8 +24,10 @@ ts_test_suite( "//protos:ts", "//testing", "@npm//@types/chai", + "@npm//@types/fs-extra", "@npm//@types/long", "@npm//@types/node", "@npm//chai", + "@npm//fs-extra", ], ) diff --git a/tests/integration/bigquery.spec.ts b/tests/integration/bigquery.spec.ts index 01b03ce23..de5c97c65 100644 --- a/tests/integration/bigquery.spec.ts +++ b/tests/integration/bigquery.spec.ts @@ -1,5 +1,7 @@ import { expect } from "chai"; +import * as fs from "fs-extra"; import Long from "long"; +import * as path from "path"; import * as dfapi from "df/cli/api"; import * as dbadapters from "df/cli/api/dbadapters"; @@ -486,6 +488,129 @@ suite("@dataform/integration/bigquery", { parallel: true }, ({ before, after }) expect(partialSearch.length).equals(2); expect(columnSearch.length).greaterThan(0); }); + + test("JiT execution e2e", { timeout: 120000 }, async () => { + // Create a simple project with a JiT table + const projectDir = "tests/integration/jit_project"; + if (fs.existsSync(projectDir)) { + fs.removeSync(projectDir); + } + fs.mkdirpSync(path.join(projectDir, "definitions")); + fs.writeFileSync( + path.join(projectDir, "workflow_settings.yaml"), + ` +defaultProject: dataform-open-source +defaultLocation: US +defaultDataset: df_integration_test_jit +` + ); + fs.writeFileSync( + path.join(projectDir, "definitions/jit_table.js"), + `publish("jit_table", { type: "table" }).jitCode(async (jctx) => "SELECT 1 as id")` + ); + + // Mock @dataform/core to avoid npm install and 403 error + const nodeModulesDir = path.join(projectDir, "node_modules", "@dataform", "core"); + fs.mkdirpSync(nodeModulesDir); + const coreBundlePath = path.resolve("packages/@dataform/core/bundle.js"); + fs.copyFileSync(coreBundlePath, path.join(nodeModulesDir, "bundle.js")); + const corePackageJsonPath = path.resolve("packages/@dataform/core/package.json"); + fs.copyFileSync(corePackageJsonPath, path.join(nodeModulesDir, "package.json")); + // We also need a package.json in the project root to bypass the dataformCoreVersion check + fs.writeFileSync( + path.join(projectDir, "package.json"), + JSON.stringify({ dependencies: { "@dataform/core": "3.0.0-alpha.0" } }) + ); + + try { + const compiledGraph = await dfapi.compile({ projectDir }); + + // Drop dataset to start fresh + await dbadapter.execute( + "drop schema if exists `dataform-open-source.df_integration_test_jit` cascade" + ); + + const executionGraph = await dfapi.build(compiledGraph, {}, dbadapter); + const runResult = await dfapi.run(dbadapter, executionGraph, { projectDir }).result(); + + expect(dataform.RunResult.ExecutionStatus[runResult.status]).eql( + dataform.RunResult.ExecutionStatus[dataform.RunResult.ExecutionStatus.SUCCESSFUL] + ); + + const rows = await dbadapter.execute("SELECT * FROM `dataform-open-source.df_integration_test_jit.jit_table`").then(res => res.rows); + expect(rows).to.eql([{ id: 1 }]); + } finally { + if (fs.existsSync(projectDir)) { + fs.removeSync(projectDir); + } + } + }); + + test("JiT dry run integration", { timeout: 120000 }, async () => { + // Create a simple project with a JiT table + const projectDir = "tests/integration/jit_dry_run_project"; + if (fs.existsSync(projectDir)) { + fs.removeSync(projectDir); + } + fs.mkdirpSync(path.join(projectDir, "definitions")); + fs.writeFileSync( + path.join(projectDir, "workflow_settings.yaml"), + ` +defaultProject: dataform-open-source +defaultLocation: US +defaultDataset: df_integration_test_jit_dry_run +` + ); + fs.writeFileSync( + path.join(projectDir, "definitions/jit_table.js"), + `publish("jit_table_dry_run", { type: "table" }).jitCode(async (jctx) => "SELECT 1 as id")` + ); + + // Mock @dataform/core + const nodeModulesDir = path.join(projectDir, "node_modules", "@dataform", "core"); + fs.mkdirpSync(nodeModulesDir); + const coreBundlePath = path.resolve("packages/@dataform/core/bundle.js"); + fs.copyFileSync(coreBundlePath, path.join(nodeModulesDir, "bundle.js")); + const corePackageJsonPath = path.resolve("packages/@dataform/core/package.json"); + fs.copyFileSync(corePackageJsonPath, path.join(nodeModulesDir, "package.json")); + fs.writeFileSync( + path.join(projectDir, "package.json"), + JSON.stringify({ dependencies: { "@dataform/core": "3.0.0-alpha.0" } }) + ); + + try { + const compiledGraph = await dfapi.compile({ projectDir }); + + // Drop dataset to start fresh + await dbadapter.execute( + "drop schema if exists `dataform-open-source.df_integration_test_jit_dry_run` cascade" + ); + + const executionGraph = await dfapi.build(compiledGraph, {}, dbadapter); + + const runResult = await dfapi.run(dbadapter, executionGraph, { + projectDir, + bigquery: { dryRun: true } + }).result(); + + expect(dataform.RunResult.ExecutionStatus[runResult.status]).eql( + dataform.RunResult.ExecutionStatus[dataform.RunResult.ExecutionStatus.SUCCESSFUL] + ); + + // Verify that the table was NOT created + const tables = await dbadapter.schemas("dataform-open-source").then(schemas => { + if (!schemas.includes("df_integration_test_jit_dry_run")) { + return []; + } + return dbadapter.tables("dataform-open-source", "df_integration_test_jit_dry_run"); + }); + expect(tables.length).to.equal(0); + } finally { + if (fs.existsSync(projectDir)) { + fs.removeSync(projectDir); + } + } + }); }); async function cleanWarehouse(